diff options
Diffstat (limited to 'synapse')
33 files changed, 1174 insertions, 988 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 3bb5b3da37..ad68079eeb 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -23,12 +23,13 @@ class Ratelimiter(object): def __init__(self): self.message_counts = collections.OrderedDict() - def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): - """Can the user send a message? + def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + """Can the entity (e.g. user or IP address) perform the action? Args: - user_id: The user sending a message. + key: The key we should use when rate limiting. Can be a user ID + (when sending events), an IP address, etc. time_now_s: The time now. - msg_rate_hz: The long term number of messages a user can send in a + rate_hz: The long term number of messages a user can send in a second. burst_count: How many messages the user can send before being limited. @@ -41,10 +42,10 @@ class Ratelimiter(object): """ self.prune_message_counts(time_now_s) message_count, time_start, _ignored = self.message_counts.get( - user_id, (0., time_now_s, None), + key, (0., time_now_s, None), ) time_delta = time_now_s - time_start - sent_count = message_count - time_delta * msg_rate_hz + sent_count = message_count - time_delta * rate_hz if sent_count < 0: allowed = True time_start = time_now_s @@ -56,13 +57,13 @@ class Ratelimiter(object): message_count += 1 if update: - self.message_counts[user_id] = ( - message_count, time_start, msg_rate_hz + self.message_counts[key] = ( + message_count, time_start, rate_hz ) - if msg_rate_hz > 0: + if rate_hz > 0: time_allowed = ( - time_start + (message_count - burst_count + 1) / msg_rate_hz + time_start + (message_count - burst_count + 1) / rate_hz ) if time_allowed < time_now_s: time_allowed = time_now_s @@ -72,12 +73,12 @@ class Ratelimiter(object): return allowed, time_allowed def prune_message_counts(self, time_now_s): - for user_id in list(self.message_counts.keys()): - message_count, time_start, msg_rate_hz = ( - self.message_counts[user_id] + for key in list(self.message_counts.keys()): + message_count, time_start, rate_hz = ( + self.message_counts[key] ) time_delta = time_now_s - time_start - if message_count - time_delta * msg_rate_hz > 0: + if message_count - time_delta * rate_hz > 0: break else: - del self.message_counts[user_id] + del self.message_counts[key] diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 5070094cad..beaea64a61 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -33,9 +33,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore @@ -49,6 +53,7 @@ from synapse.rest.client.v1.room import ( RoomStateRestServlet, ) from synapse.rest.client.v2_alpha.account import ThreepidRestServlet +from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -61,6 +66,10 @@ logger = logging.getLogger("synapse.app.client_reader") class ClientReaderSlavedStore( + SlavedDeviceInboxStore, + SlavedDeviceStore, + SlavedReceiptsStore, + SlavedPushRuleStore, SlavedAccountDataStore, SlavedEventStore, SlavedKeyStore, @@ -98,6 +107,8 @@ class ClientReaderServer(HomeServer): RegisterRestServlet(self).register(resource) LoginRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource) + KeyQueryServlet(self).register(resource) + KeyChangesServlet(self).register(resource) resources.update({ "/_matrix/client/r0": resource, diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 5aec43b702..c4d3087fa4 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -180,9 +180,7 @@ class Config(object): Returns: str: the yaml config file """ - default_config = "# vim:ft=yaml\n" - - default_config += "\n\n".join( + default_config = "\n\n".join( dedent(conf) for conf in self.invoke_all( "default_config", @@ -297,19 +295,26 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) + + config_str = obj.generate_config( + config_dir_path=config_dir_path, + data_dir_path=os.getcwd(), + server_name=server_name, + report_stats=(config_args.report_stats == "yes"), + generate_secrets=True, + ) + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_str = obj.generate_config( - config_dir_path=config_dir_path, - data_dir_path=os.getcwd(), - server_name=server_name, - report_stats=(config_args.report_stats == "yes"), - generate_secrets=True, + config_file.write( + "# vim:ft=yaml\n\n" ) - config = yaml.load(config_str) - obj.invoke_all("generate_files", config) config_file.write(config_str) + + config = yaml.load(config_str) + obj.invoke_all("generate_files", config) + print( ( "A config file has been generated in %r for server name" diff --git a/synapse/config/database.py b/synapse/config/database.py index c8890147a6..63e9cb63f8 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -49,7 +49,8 @@ class DatabaseConfig(Config): def default_config(self, data_dir_path, **kwargs): database_path = os.path.join(data_dir_path, "homeserver.db") return """\ - # Database configuration + ## Database ## + database: # The database engine name name: "sqlite3" diff --git a/synapse/config/logger.py b/synapse/config/logger.py index f6940b65fd..464c28c2d9 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -81,7 +81,9 @@ class LoggingConfig(Config): def default_config(self, config_dir_path, server_name, **kwargs): log_config = os.path.join(config_dir_path, server_name + ".log.config") - return """ + return """\ + ## Logging ## + # A yaml python logging config file # log_config: "%(log_config)s" diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 54b71e6841..093042fdb9 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -27,6 +27,13 @@ class RatelimitConfig(Config): self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_concurrent = config["federation_rc_concurrent"] + self.rc_registration_requests_per_second = config.get( + "rc_registration_requests_per_second", 0.17, + ) + self.rc_registration_request_burst_count = config.get( + "rc_registration_request_burst_count", 3, + ) + def default_config(self, **kwargs): return """\ ## Ratelimiting ## @@ -62,4 +69,15 @@ class RatelimitConfig(Config): # single server # federation_rc_concurrent: 3 + + # Number of registration requests a client can send per second. + # Defaults to 1/minute (0.17). + # + #rc_registration_requests_per_second: 0.17 + + # Number of registration requests a client can send before being + # throttled. + # Defaults to 3. + # + #rc_registration_request_burst_count: 3.0 """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 2881482f96..d34dc9e456 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -64,6 +64,8 @@ class RegistrationConfig(Config): return """\ ## Registration ## + # Registration can be rate-limited using the parameters in the "Ratelimiting" + # section of this file. # Enable registration for new users. enable_registration: False diff --git a/synapse/config/server.py b/synapse/config/server.py index 4200f10da3..35a322fee0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -260,9 +260,11 @@ class ServerConfig(Config): # This is used by remote servers to connect to this server, # e.g. matrix.org, localhost:8080, etc. # This is also the last part of your UserID. + # server_name: "%(server_name)s" # When running as a daemon, the file to store the pid in + # pid_file: %(pid_file)s # CPU affinity mask. Setting this restricts the CPUs on which the @@ -304,9 +306,11 @@ class ServerConfig(Config): # Set the soft limit on the number of file descriptors synapse can use # Zero is used to indicate synapse should set the soft limit to the # hard limit. + # soft_file_limit: 0 # Set to false to disable presence tracking on this homeserver. + # use_presence: true # The GC threshold parameters to pass to `gc.set_threshold`, if defined diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 6f5995735a..b7d0b25781 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -159,8 +159,12 @@ class FederationRemoteSendQueue(object): # stream. pass - def send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu(self, destination, edu_type, content, key=None): """As per TransactionQueue""" + if destination == self.server_name: + logger.info("Not sending EDU to ourselves") + return + pos = self._next_pos() edu = Edu( @@ -465,15 +469,11 @@ def process_rows_for_federation(transaction_queue, rows): for destination, edu_map in iteritems(buff.keyed_edus): for key, edu in edu_map.items(): - transaction_queue.send_edu( - edu.destination, edu.edu_type, edu.content, key=key, - ) + transaction_queue.send_edu(edu, key) for destination, edu_list in iteritems(buff.edus): for edu in edu_list: - transaction_queue.send_edu( - edu.destination, edu.edu_type, edu.content, key=None, - ) + transaction_queue.send_edu(edu, None) for destination in buff.device_destinations: transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 30941f5ad6..e5e42c647d 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -361,7 +361,19 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) - def send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu(self, destination, edu_type, content, key=None): + """Construct an Edu object, and queue it for sending + + Args: + destination (str): name of server to send to + edu_type (str): type of EDU to send + content (dict): content of EDU + key (Any|None): clobbering key for this edu + """ + if destination == self.server_name: + logger.info("Not sending EDU to ourselves") + return + edu = Edu( origin=self.server_name, destination=destination, @@ -369,18 +381,23 @@ class TransactionQueue(object): content=content, ) - if destination == self.server_name: - logger.info("Not sending EDU to ourselves") - return + self.send_edu(edu, key) + + def send_edu(self, edu, key): + """Queue an EDU for sending + Args: + edu (Edu): edu to send + key (Any|None): clobbering key for this edu + """ if key: self.pending_edus_keyed_by_dest.setdefault( - destination, {} + edu.destination, {} )[(edu.edu_type, key)] = edu else: - self.pending_edus_by_dest.setdefault(destination, []).append(edu) + self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu) - self._attempt_new_transaction(destination) + self._attempt_new_transaction(edu.destination) def send_device_messages(self, destination): if destination == self.server_name: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index ebb81be377..96d680a5ad 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -759,7 +759,7 @@ class FederationVersionServlet(BaseFederationServlet): class FederationGroupsProfileServlet(BaseFederationServlet): """Get/set the basic profile of a group on behalf of a user """ - PATH = "/groups/(?P<group_id>[^/]*)/profile$" + PATH = "/groups/(?P<group_id>[^/]*)/profile" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -787,7 +787,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet): class FederationGroupsSummaryServlet(BaseFederationServlet): - PATH = "/groups/(?P<group_id>[^/]*)/summary$" + PATH = "/groups/(?P<group_id>[^/]*)/summary" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -805,7 +805,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): class FederationGroupsRoomsServlet(BaseFederationServlet): """Get the rooms in a group on behalf of a user """ - PATH = "/groups/(?P<group_id>[^/]*)/rooms$" + PATH = "/groups/(?P<group_id>[^/]*)/rooms" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -823,7 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsServlet(BaseFederationServlet): """Add/remove room from group """ - PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$" + PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, room_id): @@ -855,7 +855,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): """ PATH = ( "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" - "/config/(?P<config_key>[^/]*)$" + "/config/(?P<config_key>[^/]*)" ) @defer.inlineCallbacks @@ -874,7 +874,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): class FederationGroupsUsersServlet(BaseFederationServlet): """Get the users in a group on behalf of a user """ - PATH = "/groups/(?P<group_id>[^/]*)/users$" + PATH = "/groups/(?P<group_id>[^/]*)/users" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -892,7 +892,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): """Get the users that have been invited to a group """ - PATH = "/groups/(?P<group_id>[^/]*)/invited_users$" + PATH = "/groups/(?P<group_id>[^/]*)/invited_users" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -910,7 +910,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): class FederationGroupsInviteServlet(BaseFederationServlet): """Ask a group server to invite someone to the group """ - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -928,7 +928,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): """Accept an invitation from the group server """ - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$" + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -945,7 +945,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): class FederationGroupsJoinServlet(BaseFederationServlet): """Attempt to join a group """ - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$" + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -962,7 +962,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet): """Leave or kick a user from the group """ - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -980,7 +980,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsLocalInviteServlet(BaseFederationServlet): """A group server has invited a local user """ - PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -997,7 +997,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): """A group server has removed a local user """ - PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -1014,7 +1014,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): """A group or user's server renews their attestation """ - PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$" + PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" @defer.inlineCallbacks def on_POST(self, origin, content, query, group_id, user_id): @@ -1037,7 +1037,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): PATH = ( "/groups/(?P<group_id>[^/]*)/summary" "(/categories/(?P<category_id>[^/]+))?" - "/rooms/(?P<room_id>[^/]*)$" + "/rooms/(?P<room_id>[^/]*)" ) @defer.inlineCallbacks @@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): """Get all categories for a group """ PATH = ( - "/groups/(?P<group_id>[^/]*)/categories/$" + "/groups/(?P<group_id>[^/]*)/categories/" ) @defer.inlineCallbacks @@ -1100,7 +1100,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): """Add/remove/get a category in a group """ PATH = ( - "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" ) @defer.inlineCallbacks @@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): """Get roles in a group """ PATH = ( - "/groups/(?P<group_id>[^/]*)/roles/$" + "/groups/(?P<group_id>[^/]*)/roles/" ) @defer.inlineCallbacks @@ -1170,7 +1170,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): """Add/remove/get a role in a group """ PATH = ( - "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" ) @defer.inlineCallbacks @@ -1226,7 +1226,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): PATH = ( "/groups/(?P<group_id>[^/]*)/summary" "(/roles/(?P<role_id>[^/]+))?" - "/users/(?P<user_id>[^/]*)$" + "/users/(?P<user_id>[^/]*)" ) @defer.inlineCallbacks @@ -1269,7 +1269,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): """Get roles in a group """ PATH = ( - "/get_groups_publicised$" + "/get_groups_publicised" ) @defer.inlineCallbacks @@ -1284,7 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): """Sets whether a group is joinable without an invite or knock """ - PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$" + PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" @defer.inlineCallbacks def on_PUT(self, origin, content, query, group_id): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 594754cfd8..d8d86d6ff3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -93,9 +93,9 @@ class BaseHandler(object): messages_per_second = self.hs.config.rc_messages_per_second burst_count = self.hs.config.rc_message_burst_count - allowed, time_allowed = self.ratelimiter.send_message( + allowed, time_allowed = self.ratelimiter.can_do_action( user_id, time_now, - msg_rate_hz=messages_per_second, + rate_hz=messages_per_second, burst_count=burst_count, update=update, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c708c35d4d..c09a7c6280 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -37,13 +37,185 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -class DeviceHandler(BaseHandler): +class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): - super(DeviceHandler, self).__init__(hs) + super(DeviceWorkerHandler, self).__init__(hs) self.hs = hs self.state = hs.get_state_handler() self._auth_handler = hs.get_auth_handler() + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: list[dict[str, X]]: info on each device + """ + + device_map = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + user_id, device_id=None + ) + + devices = list(device_map.values()) + for device in devices: + _update_device_from_client_ips(device, ips) + + defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + user_id, device_id, + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + @measure_func("device.get_user_ids_changed") + @defer.inlineCallbacks + def get_user_ids_changed(self, user_id, from_token): + """Get list of users that have had the devices updated, or have newly + joined a room, that `user_id` may be interested in. + + Args: + user_id (str) + from_token (StreamToken) + """ + now_room_key = yield self.store.get_room_events_max_id() + + room_ids = yield self.store.get_rooms_for_user(user_id) + + # First we check if any devices have changed + changed = yield self.store.get_user_whose_devices_changed( + from_token.device_list_key + ) + + # Then work out if any users have since joined + rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + + member_events = yield self.store.get_membership_changes_for_user( + user_id, from_token.room_key, now_room_key, + ) + rooms_changed.update(event.room_id for event in member_events) + + stream_ordering = RoomStreamToken.parse_stream_token( + from_token.room_key + ).stream + + possibly_changed = set(changed) + possibly_left = set() + for room_id in rooms_changed: + current_state_ids = yield self.store.get_current_state_ids(room_id) + + # The user may have left the room + # TODO: Check if they actually did or if we were just invited. + if room_id not in room_ids: + for key, event_id in iteritems(current_state_ids): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_left.add(state_key) + continue + + # Fetch the current state at the time. + try: + event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_ordering=stream_ordering + ) + except errors.StoreError: + # we have purged the stream_ordering index since the stream + # ordering: treat it the same as a new room + event_ids = [] + + # special-case for an empty prev state: include all members + # in the changed list + if not event_ids: + for key, event_id in iteritems(current_state_ids): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_changed.add(state_key) + continue + + current_member_id = current_state_ids.get((EventTypes.Member, user_id)) + if not current_member_id: + continue + + # mapping from event_id -> state_dict + prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + + # Check if we've joined the room? If so we just blindly add all the users to + # the "possibly changed" users. + for state_dict in itervalues(prev_state_ids): + member_event = state_dict.get((EventTypes.Member, user_id), None) + if not member_event or member_event != current_member_id: + for key, event_id in iteritems(current_state_ids): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_changed.add(state_key) + break + + # If there has been any change in membership, include them in the + # possibly changed list. We'll check if they are joined below, + # and we're not toooo worried about spuriously adding users. + for key, event_id in iteritems(current_state_ids): + etype, state_key = key + if etype != EventTypes.Member: + continue + + # check if this member has changed since any of the extremities + # at the stream_ordering, and add them to the list if so. + for state_dict in itervalues(prev_state_ids): + prev_event_id = state_dict.get(key, None) + if not prev_event_id or prev_event_id != event_id: + if state_key != user_id: + possibly_changed.add(state_key) + break + + if possibly_changed or possibly_left: + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + # Take the intersection of the users whose devices may have changed + # and those that actually still share a room with the user + possibly_joined = possibly_changed & users_who_share_room + possibly_left = (possibly_changed | possibly_left) - users_who_share_room + else: + possibly_joined = [] + possibly_left = [] + + defer.returnValue({ + "changed": list(possibly_joined), + "left": list(possibly_left), + }) + + +class DeviceHandler(DeviceWorkerHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + self.federation_sender = hs.get_federation_sender() self._edu_updater = DeviceListEduUpdater(hs, self) @@ -104,52 +276,6 @@ class DeviceHandler(BaseHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") @defer.inlineCallbacks - def get_devices_by_user(self, user_id): - """ - Retrieve the given user's devices - - Args: - user_id (str): - Returns: - defer.Deferred: list[dict[str, X]]: info on each device - """ - - device_map = yield self.store.get_devices_by_user(user_id) - - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id=None - ) - - devices = list(device_map.values()) - for device in devices: - _update_device_from_client_ips(device, ips) - - defer.returnValue(devices) - - @defer.inlineCallbacks - def get_device(self, user_id, device_id): - """ Retrieve the given device - - Args: - user_id (str): - device_id (str): - - Returns: - defer.Deferred: dict[str, X]: info on the device - Raises: - errors.NotFoundError: if the device was not found - """ - try: - device = yield self.store.get_device(user_id, device_id) - except errors.StoreError: - raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id, - ) - _update_device_from_client_ips(device, ips) - defer.returnValue(device) - - @defer.inlineCallbacks def delete_device(self, user_id, device_id): """ Delete the given device @@ -287,126 +413,6 @@ class DeviceHandler(BaseHandler): for host in hosts: self.federation_sender.send_device_messages(host) - @measure_func("device.get_user_ids_changed") - @defer.inlineCallbacks - def get_user_ids_changed(self, user_id, from_token): - """Get list of users that have had the devices updated, or have newly - joined a room, that `user_id` may be interested in. - - Args: - user_id (str) - from_token (StreamToken) - """ - now_token = yield self.hs.get_event_sources().get_current_token() - - room_ids = yield self.store.get_rooms_for_user(user_id) - - # First we check if any devices have changed - changed = yield self.store.get_user_whose_devices_changed( - from_token.device_list_key - ) - - # Then work out if any users have since joined - rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) - - member_events = yield self.store.get_membership_changes_for_user( - user_id, from_token.room_key, now_token.room_key - ) - rooms_changed.update(event.room_id for event in member_events) - - stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key - ).stream - - possibly_changed = set(changed) - possibly_left = set() - for room_id in rooms_changed: - current_state_ids = yield self.store.get_current_state_ids(room_id) - - # The user may have left the room - # TODO: Check if they actually did or if we were just invited. - if room_id not in room_ids: - for key, event_id in iteritems(current_state_ids): - etype, state_key = key - if etype != EventTypes.Member: - continue - possibly_left.add(state_key) - continue - - # Fetch the current state at the time. - try: - event_ids = yield self.store.get_forward_extremeties_for_room( - room_id, stream_ordering=stream_ordering - ) - except errors.StoreError: - # we have purged the stream_ordering index since the stream - # ordering: treat it the same as a new room - event_ids = [] - - # special-case for an empty prev state: include all members - # in the changed list - if not event_ids: - for key, event_id in iteritems(current_state_ids): - etype, state_key = key - if etype != EventTypes.Member: - continue - possibly_changed.add(state_key) - continue - - current_member_id = current_state_ids.get((EventTypes.Member, user_id)) - if not current_member_id: - continue - - # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) - - # Check if we've joined the room? If so we just blindly add all the users to - # the "possibly changed" users. - for state_dict in itervalues(prev_state_ids): - member_event = state_dict.get((EventTypes.Member, user_id), None) - if not member_event or member_event != current_member_id: - for key, event_id in iteritems(current_state_ids): - etype, state_key = key - if etype != EventTypes.Member: - continue - possibly_changed.add(state_key) - break - - # If there has been any change in membership, include them in the - # possibly changed list. We'll check if they are joined below, - # and we're not toooo worried about spuriously adding users. - for key, event_id in iteritems(current_state_ids): - etype, state_key = key - if etype != EventTypes.Member: - continue - - # check if this member has changed since any of the extremities - # at the stream_ordering, and add them to the list if so. - for state_dict in itervalues(prev_state_ids): - prev_event_id = state_dict.get(key, None) - if not prev_event_id or prev_event_id != event_id: - if state_key != user_id: - possibly_changed.add(state_key) - break - - if possibly_changed or possibly_left: - users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id - ) - - # Take the intersection of the users whose devices may have changed - # and those that actually still share a room with the user - possibly_joined = possibly_changed & users_who_share_room - possibly_left = (possibly_changed | possibly_left) - users_who_share_room - else: - possibly_joined = [] - possibly_left = [] - - defer.returnValue({ - "changed": list(possibly_joined), - "left": list(possibly_left), - }) - @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f80486102a..72b63d64d0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -858,6 +858,52 @@ class FederationHandler(BaseHandler): logger.debug("Not backfilling as no extremeties found.") return + # We only want to paginate if we can actually see the events we'll get, + # as otherwise we'll just spend a lot of resources to get redacted + # events. + # + # We do this by filtering all the backwards extremities and seeing if + # any remain. Given we don't have the extremity events themselves, we + # need to actually check the events that reference them. + # + # *Note*: the spec wants us to keep backfilling until we reach the start + # of the room in case we are allowed to see some of the history. However + # in practice that causes more issues than its worth, as a) its + # relatively rare for there to be any visible history and b) even when + # there is its often sufficiently long ago that clients would stop + # attempting to paginate before backfill reached the visible history. + # + # TODO: If we do do a backfill then we should filter the backwards + # extremities to only include those that point to visible portions of + # history. + # + # TODO: Correctly handle the case where we are allowed to see the + # forward event but not the backward extremity, e.g. in the case of + # initial join of the server where we are allowed to see the join + # event but not anything before it. This would require looking at the + # state *before* the event, ignoring the special casing certain event + # types have. + + forward_events = yield self.store.get_successor_events( + list(extremities), + ) + + extremities_events = yield self.store.get_events( + forward_events, + check_redacted=False, + get_prev_content=False, + ) + + # We set `check_history_visibility_only` as we might otherwise get false + # positives from users having been erased. + filtered_extremities = yield filter_events_for_server( + self.store, self.server_name, list(extremities_events.values()), + redact=False, check_history_visibility_only=True, + ) + + if not filtered_extremities: + defer.returnValue(False) + # Check if we reached a point where we should start backfilling. sorted_extremeties_tuple = sorted( extremities.items(), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3981fe69ce..c762b58902 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -436,10 +436,11 @@ class EventCreationHandler(object): if event.is_state(): prev_state = yield self.deduplicate_state_event(event, context) - logger.info( - "Not bothering to persist duplicate state event %s", event.event_id, - ) if prev_state is not None: + logger.info( + "Not bothering to persist state event %s duplicated by %s", + event.event_id, prev_state.event_id, + ) defer.returnValue(prev_state) yield self.handle_new_client_event( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index ba3856674d..37e87fc054 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -816,7 +816,7 @@ class PresenceHandler(object): if self.is_mine(observed_user): yield self.invite_presence(observed_user, observer_user) else: - yield self.federation.send_edu( + yield self.federation.build_and_send_edu( destination=observed_user.domain, edu_type="m.presence_invite", content={ @@ -836,7 +836,7 @@ class PresenceHandler(object): if self.is_mine(observer_user): yield self.accept_presence(observed_user, observer_user) else: - self.federation.send_edu( + self.federation.build_and_send_edu( destination=observer_user.domain, edu_type="m.presence_accept", content={ @@ -848,7 +848,7 @@ class PresenceHandler(object): state_dict = yield self.get_state(observed_user, as_event=False) state_dict = format_user_presence_state(state_dict, self.clock.time_msec()) - self.federation.send_edu( + self.federation.build_and_send_edu( destination=observer_user.domain, edu_type="m.presence", content={ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 696469732c..1728089667 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -16,7 +16,6 @@ import logging from twisted.internet import defer -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id from ._base import BaseHandler @@ -39,31 +38,6 @@ class ReceiptsHandler(BaseHandler): self.state = hs.get_state_handler() @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, - event_id): - """Called when a client tells us a local user has read up to the given - event_id in the room. - """ - receipt = { - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - "event_ids": [event_id], - "data": { - "ts": int(self.clock.time_msec()), - } - } - - is_new = yield self._handle_new_receipts([receipt]) - - if is_new: - # fire off a process in the background to send the receipt to - # remote servers - run_as_background_process( - 'push_receipts_to_remotes', self._push_remotes, receipt - ) - - @defer.inlineCallbacks def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ @@ -128,43 +102,54 @@ class ReceiptsHandler(BaseHandler): defer.returnValue(True) @defer.inlineCallbacks - def _push_remotes(self, receipt): - """Given a receipt, works out which remote servers should be - poked and pokes them. + def received_client_receipt(self, room_id, receipt_type, user_id, + event_id): + """Called when a client tells us a local user has read up to the given + event_id in the room. """ - try: - # TODO: optimise this to move some of the work to the workers. - room_id = receipt["room_id"] - receipt_type = receipt["receipt_type"] - user_id = receipt["user_id"] - event_ids = receipt["event_ids"] - data = receipt["data"] + receipt = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": [event_id], + "data": { + "ts": int(self.clock.time_msec()), + } + } - users = yield self.state.get_current_user_in_room(room_id) - remotedomains = set(get_domain_from_id(u) for u in users) - remotedomains = remotedomains.copy() - remotedomains.discard(self.server_name) - - logger.debug("Sending receipt to: %r", remotedomains) - - for domain in remotedomains: - self.federation.send_edu( - destination=domain, - edu_type="m.receipt", - content={ - room_id: { - receipt_type: { - user_id: { - "event_ids": event_ids, - "data": data, - } + is_new = yield self._handle_new_receipts([receipt]) + if not is_new: + return + + # Work out which remote servers should be poked and poke them. + + # TODO: optimise this to move some of the work to the workers. + data = receipt["data"] + + # XXX why does this not use state.get_current_hosts_in_room() ? + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) + + logger.debug("Sending receipt to: %r", remotedomains) + + for domain in remotedomains: + self.federation.build_and_send_edu( + destination=domain, + edu_type="m.receipt", + content={ + room_id: { + receipt_type: { + user_id: { + "event_ids": [event_id], + "data": data, } - }, + } }, - key=(room_id, receipt_type, user_id), - ) - except Exception: - logger.exception("Error pushing receipts to remote servers") + }, + key=(room_id, receipt_type, user_id), + ) @defer.inlineCallbacks def get_receipts_for_room(self, room_id, to_key): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c0e06929bd..03130edc54 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -24,6 +24,7 @@ from synapse.api.errors import ( AuthError, Codes, InvalidCaptchaError, + LimitExceededError, RegistrationError, SynapseError, ) @@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self.identity_handler = self.hs.get_handlers().identity_handler + self.ratelimiter = hs.get_registration_ratelimiter() self._next_generated_user_id = None @@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler): threepid=None, user_type=None, default_display_name=None, + address=None, ): """Registers a new client on the server. @@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler): api.constants.UserTypes, or None for a normal user. default_display_name (unicode|None): if set, the new user's displayname will be set to this. Defaults to 'localpart'. + address (str|None): the IP address used to perform the regitration. Returns: A tuple of (user_id, access_token). Raises: @@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: token = self.macaroon_gen.generate_access_token(user_id) - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, @@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=default_display_name, admin=admin, user_type=user_type, + address=address, ) if self.hs.config.user_directory_search_all_users: @@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler): if default_display_name is None: default_display_name = localpart try: - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, make_guest=make_guest, create_profile_with_displayname=default_display_name, + address=address, ) except SynapseError: # if user id is taken, just generate another @@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, password_hash="", appservice_id=service_id, @@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler): token = self.macaroon_gen.generate_access_token(user_id) if need_register: - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, @@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) - def _register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None): + def register_with_store(self, user_id, token=None, password_hash=None, + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_displayname=None, admin=False, + user_type=None, address=None): """Register user in the datastore. Args: @@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + address (str|None): the IP address used to perform the regitration. Returns: Deferred """ + # Don't rate limit for app services + if appservice_id is None and address is not None: + time_now = self.clock.time() + + allowed, time_allowed = self.ratelimiter.can_do_action( + address, time_now_s=time_now, + rate_hz=self.hs.config.rc_registration_requests_per_second, + burst_count=self.hs.config.rc_registration_request_burst_count, + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + if self.hs.config.worker_app: return self._register_client( user_id=user_id, @@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + address=address, ) else: return self.store.register( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a61bbf9392..39df960c31 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -231,7 +231,7 @@ class TypingHandler(object): for domain in set(get_domain_from_id(u) for u in users): if domain != self.server_name: logger.debug("sending typing update to %s", domain) - self.federation.send_edu( + self.federation.build_and_send_edu( destination=domain, edu_type="m.typing", content={ diff --git a/synapse/notifier.py b/synapse/notifier.py index de02b1017e..ff589660da 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -178,8 +178,6 @@ class Notifier(object): self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS ) - self.replication_deferred = ObservableDeferred(defer.Deferred()) - # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. @@ -205,7 +203,9 @@ class Notifier(object): def add_replication_callback(self, cb): """Add a callback that will be called when some new data is available. - Callback is not given any arguments. + Callback is not given any arguments. It should *not* return a Deferred - if + it needs to do any asynchronous work, a background thread should be started and + wrapped with run_as_background_process. """ self.replication_callbacks.append(cb) @@ -517,60 +517,5 @@ class Notifier(object): def notify_replication(self): """Notify the any replication listeners that there's a new event""" - with PreserveLoggingContext(): - deferred = self.replication_deferred - self.replication_deferred = ObservableDeferred(defer.Deferred()) - deferred.callback(None) - - # the callbacks may well outlast the current request, so we run - # them in the sentinel logcontext. - # - # (ideally it would be up to the callbacks to know if they were - # starting off background processes and drop the logcontext - # accordingly, but that requires more changes) - for cb in self.replication_callbacks: - cb() - - @defer.inlineCallbacks - def wait_for_replication(self, callback, timeout): - """Wait for an event to happen. - - Args: - callback: Gets called whenever an event happens. If this returns a - truthy value then ``wait_for_replication`` returns, otherwise - it waits for another event. - timeout: How many milliseconds to wait for callback return a truthy - value. - - Returns: - A deferred that resolves with the value returned by the callback. - """ - listener = _NotificationListener(None) - - end_time = self.clock.time_msec() + timeout - - while True: - listener.deferred = self.replication_deferred.observe() - result = yield callback() - if result: - break - - now = self.clock.time_msec() - if end_time <= now: - break - - listener.deferred = timeout_deferred( - listener.deferred, - timeout=(end_time - now) / 1000., - reactor=self.hs.get_reactor(), - ) - - try: - with PreserveLoggingContext(): - yield listener.deferred - except defer.TimeoutError: - break - except defer.CancelledError: - break - - defer.returnValue(result) + for cb in self.replication_callbacks: + cb() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 1d27c9221f..912a5ac341 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs): super(ReplicationRegisterServlet, self).__init__(hs) self.store = hs.get_datastore() + self.registration_handler = hs.get_registration_handler() @staticmethod def _serialize_payload( user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, + create_profile_with_displayname, admin, user_type, address, ): """ Args: @@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + address (str|None): the IP address used to perform the regitration. """ return { "token": token, @@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "create_profile_with_displayname": create_profile_with_displayname, "admin": admin, "user_type": user_type, + "address": address, } @defer.inlineCallbacks def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.store.register( + yield self.registration_handler.register_with_store( user_id=user_id, token=content["token"], password_hash=content["password_hash"], @@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], + address=content["address"] ) defer.returnValue((200, {})) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4f19fd35aa..4d59778863 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedDeviceInboxStore(BaseSlavedStore): +class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( @@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token) - get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device) - get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote) - delete_messages_for_device = __func__(DataStore.delete_messages_for_device) - delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote) - def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() result["to_device"] = self._device_inbox_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ec2fd561cc..16c9a162c5 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore -from synapse.storage.end_to_end_keys import EndToEndKeyStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.devices import DeviceWorkerStore +from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedDeviceStore(BaseSlavedStore): +class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceStore, self).__init__(db_conn, hs) @@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore): "DeviceListFederationStreamChangeCache", device_list_max, ) - get_device_stream_token = __func__(DataStore.get_device_stream_token) - get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed) - get_devices_by_remote = __func__(DataStore.get_devices_by_remote) - _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn) - _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn) - mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote) - _mark_as_sent_devices_by_remote_txn = ( - __func__(DataStore._mark_as_sent_devices_by_remote_txn) - ) - count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] - def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() result["device_lists"] = self._device_list_id_gen.get_current_token() @@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore): if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._device_list_stream_cache.entity_has_changed( - row.user_id, token + self._invalidate_caches_for_devices( + token, row.user_id, row.destination, ) - - if row.destination: - self._device_list_federation_stream_cache.entity_has_changed( - row.destination, token - ) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) + + def _invalidate_caches_for_devices(self, token, user_id, destination): + self._device_list_stream_cache.entity_has_changed( + user_id, token + ) + + if destination: + self._device_list_federation_stream_cache.entity_has_changed( + destination, token + ) + + self._get_cached_devices_for_user.invalidate((user_id,)) + self._get_cached_user_device.invalidate_many((user_id,)) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index f0200c1e98..45fc913c52 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore -class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): +class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 82433a2aa9..0201cf1186 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -17,12 +17,14 @@ import hashlib import hmac import logging +import platform from six import text_type from six.moves import http_client from twisted.internet import defer +import synapse from synapse.api.constants import Membership, UserTypes from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( @@ -32,6 +34,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.types import UserID, create_requester +from synapse.util.versionstring import get_version_string from .base import ClientV1RestServlet, client_path_patterns @@ -66,6 +69,25 @@ class UsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class VersionServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/server_version") + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + ret = { + 'server_version': get_version_string(synapse), + 'python_version': platform.python_version(), + } + + defer.returnValue((200, ret)) + + class UserRegisterServlet(ClientV1RestServlet): """ Attributes: @@ -763,3 +785,4 @@ def register_servlets(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server) UserRegisterServlet(hs).register(http_server) + VersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 94cbba4303..6f34029431 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -25,7 +25,12 @@ from twisted.internet import defer import synapse import synapse.types from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError +from synapse.api.errors import ( + Codes, + LimitExceededError, + SynapseError, + UnrecognizedRequestError, +) from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, @@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() + self.ratelimiter = hs.get_registration_ratelimiter() + self.clock = hs.get_clock() @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) + client_addr = request.getClientIP() + + time_now = self.clock.time() + + allowed, time_allowed = self.ratelimiter.can_do_action( + client_addr, time_now_s=time_now, + rate_hz=self.hs.config.rc_registration_requests_per_second, + burst_count=self.hs.config.rc_registration_request_burst_count, + update=False, + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + kind = b"user" if b"kind" in request.args: kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body) + ret = yield self._do_guest_registration(body, address=client_addr) defer.returnValue(ret) return elif kind != b"user": @@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet): guest_access_token=guest_access_token, generate_token=False, threepid=threepid, + address=client_addr, ) # Necessary due to auth checks prior to the threepid being # written to the db @@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet): defer.returnValue(result) @defer.inlineCallbacks - def _do_guest_registration(self, params): + def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( generate_token=False, - make_guest=True + make_guest=True, + address=address, ) # we don't allow guests to specify their own device_id, because diff --git a/synapse/server.py b/synapse/server.py index 4d364fccce..72835e8c86 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -51,7 +51,7 @@ from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.deactivate_account import DeactivateAccountHandler -from synapse.handlers.device import DeviceHandler +from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler @@ -206,6 +206,7 @@ class HomeServer(object): self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() + self.registration_ratelimiter = Ratelimiter() self.datastore = None @@ -251,6 +252,9 @@ class HomeServer(object): def get_ratelimiter(self): return self.ratelimiter + def get_registration_ratelimiter(self): + return self.registration_ratelimiter + def build_federation_client(self): return FederationClient(self) @@ -307,7 +311,10 @@ class HomeServer(object): return MacaroonGenerator(self) def build_device_handler(self): - return DeviceHandler(self) + if self.config.worker_app: + return DeviceWorkerHandler(self) + else: + return DeviceHandler(self) def build_device_message_handler(self): return DeviceMessageHandler(self) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index e06b0bc56d..e6a42a53bb 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -19,14 +19,174 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache -from .background_updates import BackgroundUpdateStore - logger = logging.getLogger(__name__) -class DeviceInboxStore(BackgroundUpdateStore): +class DeviceInboxWorkerStore(SQLBaseStore): + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + user_id, device_id, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(json.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn, + ) + + @defer.inlineCallbacks + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves to the number of messages deleted. + """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + defer.returnValue(0) + + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + return txn.rowcount + + count = yield self.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + defer.returnValue(count) + + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int|long): The last position of the device message stream + that the server sent up to. + current_stream_id(int|long): The current position of the device + message stream. + Returns: + Deferred ([dict], int|long): List of messages for the device and where + in the stream the messages got to. + """ + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + destination, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(json.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.runInteraction( + "delete_device_msgs_for_remote", + delete_messages_for_remote_destination_txn + ) + + +class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, db_conn, hs): @@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.executemany(sql, rows) - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device - message stream. - Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. - """ - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_stream_id - ) - if not has_changed: - return defer.succeed(([], current_stream_id)) - - def get_new_messages_for_device_txn(txn): - sql = ( - "SELECT stream_id, message_json FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute(sql, ( - user_id, device_id, last_stream_id, current_stream_id, limit - )) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(json.loads(row[1])) - if len(messages) < limit: - stream_pos = current_stream_id - return (messages, stream_pos) - - return self.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn, - ) - - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves to the number of messages deleted. - """ - # If we have cached the last stream id we've deleted up to, we can - # check if there is likely to be anything that needs deleting - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), None - ) - if last_deleted_stream_id: - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_deleted_stream_id - ) - if not has_changed: - defer.returnValue(0) - - def delete_messages_for_device_txn(txn): - sql = ( - "DELETE FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (user_id, device_id, up_to_stream_id)) - return txn.rowcount - - count = yield self.runInteraction( - "delete_messages_for_device", delete_messages_for_device_txn - ) - - # Update the cache, ensuring that we only ever increase the value - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), 0 - ) - self._last_device_delete_cache[(user_id, device_id)] = max( - last_deleted_stream_id, up_to_stream_id - ) - - defer.returnValue(count) - def get_all_new_device_messages(self, last_pos, current_pos, limit): """ Args: @@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore): "get_all_new_device_messages", get_all_new_device_messages_txn ) - def get_to_device_stream_token(self): - return self._device_inbox_id_gen.get_current_token() - - def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit=100 - ): - """ - Args: - destination(str): The name of the remote server. - last_stream_id(int|long): The last position of the device message stream - that the server sent up to. - current_stream_id(int|long): The current position of the device - message stream. - Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. - """ - - has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( - destination, last_stream_id - ) - if not has_changed or last_stream_id == current_stream_id: - return defer.succeed(([], current_stream_id)) - - def get_new_messages_for_remote_destination_txn(txn): - sql = ( - "SELECT stream_id, messages_json FROM device_federation_outbox" - " WHERE destination = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute(sql, ( - destination, last_stream_id, current_stream_id, limit - )) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(json.loads(row[1])) - if len(messages) < limit: - stream_pos = current_stream_id - return (messages, stream_pos) - - return self.runInteraction( - "get_new_device_msgs_for_remote", - get_new_messages_for_remote_destination_txn, - ) - - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): - """Used to delete messages when the remote destination acknowledges - their receipt. - - Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. - """ - def delete_messages_for_remote_destination_txn(txn): - sql = ( - "DELETE FROM device_federation_outbox" - " WHERE destination = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (destination, up_to_stream_id)) - - return self.runInteraction( - "delete_device_msgs_for_remote", - delete_messages_for_remote_destination_txn - ) - @defer.inlineCallbacks def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index ecdab34e7d..e716dc1437 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -22,11 +22,10 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, db_to_json - logger = logging.getLogger(__name__) DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( @@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( ) -class DeviceStore(BackgroundUpdateStore): +class DeviceWorkerStore(SQLBaseStore): + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a dict containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self._simple_select_list( + table="devices", + keyvalues={"user_id": user_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user" + ) + + defer.returnValue({d["device_id"]: d for d in devices}) + + def get_devices_by_remote(self, destination, from_stream_id): + """Get stream of updates to send to remote servers + + Returns: + (int, list[dict]): current stream id and list of updates + """ + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + return (now_stream_id, []) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + LIMIT 20 + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in txn} + if not query_map: + return (now_stream_id, []) + + if len(query_map) >= 20: + now_stream_id = max(stream_id for stream_id in itervalues(query_map)) + + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + + results = [] + for user_id, user_devices in iteritems(devices): + # The prev_id for the first row is always the last row before + # `from_stream_id` + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, device in iteritems(user_devices): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + } + + prev_id = stream_id + + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True + + results.append(result) + + return (now_stream_id, results) + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. + sql = """ + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? + GROUP BY user_id + """ + txn.execute(sql, (destination, stream_id,)) + rows = txn.fetchall() + + sql = """ + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? + """ + txn.executemany( + sql, ((row[1], destination, row[0],) for row in rows if row[2]) + ) + + sql = """ + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows if not row[2]) + ) + + # Delete all sent outbound pokes + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (destination, stream_id,)) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + @defer.inlineCallbacks + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. + + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id + ) + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device + else: + results[user_id] = yield self._get_cached_devices_for_user(user_id) + + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + desc="_get_cached_user_device", + ) + defer.returnValue(db_to_json(content)) + + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + defer.returnValue({ + device["device_id"]: db_to_json(device["content"]) + for device in devices + }) + + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in iteritems(user_devices): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + """Get set of users whose devices have changed since `from_key`. + """ + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row[0] for row in rows)) + + def get_all_device_list_changes_for_remotes(self, from_key, to_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + # We do a group by here as there can be a large number of duplicate + # entries, since we throw away device IDs. + sql = """ + SELECT MAX(stream_id) AS stream_id, user_id, destination + FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id, destination + """ + return self._execute( + "get_all_device_list_changes_for_remotes", None, + sql, from_key, to_key + ) + + @cached(max_entries=10000) + def get_device_list_last_stream_id_for_remote(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_last_stream_id_for_remote", + allow_none=True, + ) + + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_device_list_last_stream_id_for_remotes", + ) + + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + +class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): def __init__(self, db_conn, hs): super(DeviceStore, self).__init__(db_conn, hs) @@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore): initial_device_display_name, e) raise StoreError(500, "Problem storing device.") - def get_device(self, user_id, device_id): - """Retrieve a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve - Returns: - defer.Deferred for a dict containing the device information - Raises: - StoreError: if the device is not found - """ - return self._simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - ) - @defer.inlineCallbacks def delete_device(self, user_id, device_id): """Delete a device. @@ -203,57 +520,6 @@ class DeviceStore(BackgroundUpdateStore): ) @defer.inlineCallbacks - def get_devices_by_user(self, user_id): - """Retrieve all of a user's registered devices. - - Args: - user_id (str): - Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. - """ - devices = yield self._simple_select_list( - table="devices", - keyvalues={"user_id": user_id}, - retcols=("user_id", "device_id", "display_name"), - desc="get_devices_by_user" - ) - - defer.returnValue({d["device_id"]: d for d in devices}) - - @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): - """Get the last stream_id we got for a user. May be None if we haven't - got any information for them. - """ - return self._simple_select_one_onecol( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - retcol="stream_id", - desc="get_device_list_remote_extremity", - allow_none=True, - ) - - @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", - list_name="user_ids", inlineCallbacks=True) - def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( - table="device_lists_remote_extremeties", - column="user_id", - iterable=user_ids, - retcols=("user_id", "stream_id",), - desc="get_user_devices_from_cache", - ) - - results = {user_id: None for user_id in user_ids} - results.update({ - row["user_id"]: row["stream_id"] for row in rows - }) - - defer.returnValue(results) - - @defer.inlineCallbacks def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ @@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore): lock=False, ) - def get_devices_by_remote(self, destination, from_stream_id): - """Get stream of updates to send to remote servers - - Returns: - (int, list[dict]): current stream id and list of updates - """ - now_stream_id = self._device_list_id_gen.get_current_token() - - has_changed = self._device_list_federation_stream_cache.has_entity_changed( - destination, int(from_stream_id) - ) - if not has_changed: - return (now_stream_id, []) - - return self.runInteraction( - "get_devices_by_remote", self._get_devices_by_remote_txn, - destination, from_stream_id, now_stream_id, - ) - - def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, - now_stream_id): - sql = """ - SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? - GROUP BY user_id, device_id - LIMIT 20 - """ - txn.execute( - sql, (destination, from_stream_id, now_stream_id, False) - ) - - # maps (user_id, device_id) -> stream_id - query_map = {(r[0], r[1]): r[2] for r in txn} - if not query_map: - return (now_stream_id, []) - - if len(query_map) >= 20: - now_stream_id = max(stream_id for stream_id in itervalues(query_map)) - - devices = self._get_e2e_device_keys_txn( - txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True - ) - - prev_sent_id_sql = """ - SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? AND stream_id <= ? - """ - - results = [] - for user_id, user_devices in iteritems(devices): - # The prev_id for the first row is always the last row before - # `from_stream_id` - txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) - rows = txn.fetchall() - prev_id = rows[0][0] - for device_id, device in iteritems(user_devices): - stream_id = query_map[(user_id, device_id)] - result = { - "user_id": user_id, - "device_id": device_id, - "prev_id": [prev_id] if prev_id else [], - "stream_id": stream_id, - } - - prev_id = stream_id - - if device is not None: - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - else: - result["deleted"] = True - - results.append(result) - - return (now_stream_id, results) - - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): - """Get the devices (and keys if any) for remote users from the cache. - - Args: - query_list(list): List of (user_id, device_ids), if device_ids is - falsey then return all device ids for that user. - - Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info - """ - user_ids = set(user_id for user_id, _ in query_list) - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) - user_ids_in_cache = set( - user_id for user_id, stream_id in user_map.items() if stream_id - ) - user_ids_not_in_cache = user_ids - user_ids_in_cache - - results = {} - for user_id, device_id in query_list: - if user_id not in user_ids_in_cache: - continue - - if device_id: - device = yield self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device - else: - results[user_id] = yield self._get_cached_devices_for_user(user_id) - - defer.returnValue((user_ids_not_in_cache, results)) - - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="content", - desc="_get_cached_user_device", - ) - defer.returnValue(db_to_json(content)) - - @cachedInlineCallbacks() - def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, - retcols=("device_id", "content"), - desc="_get_cached_devices_for_user", - ) - defer.returnValue({ - device["device_id"]: db_to_json(device["content"]) - for device in devices - }) - - def get_devices_with_keys_by_user(self, user_id): - """Get all devices (with any device keys) for a user - - Returns: - (stream_id, devices) - """ - return self.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, user_id, - ) - - def _get_devices_with_keys_by_user_txn(self, txn, user_id): - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn( - txn, [(user_id, None)], include_all_devices=True - ) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in iteritems(user_devices): - result = { - "device_id": device_id, - } - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - - def mark_as_sent_devices_by_remote(self, destination, stream_id): - """Mark that updates have successfully been sent to the destination. - """ - return self.runInteraction( - "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, - destination, stream_id, - ) - - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # We update the device_lists_outbound_last_success with the successfully - # poked users. We do the join to see which users need to be inserted and - # which updated. - sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) - FROM device_lists_outbound_pokes as o - LEFT JOIN device_lists_outbound_last_success as s - USING (destination, user_id) - WHERE destination = ? AND o.stream_id <= ? - GROUP BY user_id - """ - txn.execute(sql, (destination, stream_id,)) - rows = txn.fetchall() - - sql = """ - UPDATE device_lists_outbound_last_success - SET stream_id = ? - WHERE destination = ? AND user_id = ? - """ - txn.executemany( - sql, ((row[1], destination, row[0],) for row in rows if row[2]) - ) - - sql = """ - INSERT INTO device_lists_outbound_last_success - (destination, user_id, stream_id) VALUES (?, ?, ?) - """ - txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows if not row[2]) - ) - - # Delete all sent outbound pokes - sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? - """ - txn.execute(sql, (destination, stream_id,)) - - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - """Get set of users whose devices have changed since `from_key`. - """ - from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) - - sql = """ - SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) - defer.returnValue(set(row[0] for row in rows)) - - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. `destination` may be None if no destinations need to be poked. - """ - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. - sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination - """ - return self._execute( - "get_all_device_list_changes_for_remotes", None, - sql, from_key, to_key - ) - @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_ids, hosts): """Persist that a user's devices have been updated, and which hosts @@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore): ] ) - def get_device_stream_token(self): - return self._device_list_id_gen.get_current_token() - def _prune_old_outbound_device_pokes(self): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. We keep one entry per diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2a0f6cfca9..e381e472a2 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore, db_to_json -class EndToEndKeyStore(SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): - """Stores device keys for a device. Returns whether there was a change - or the keys were already in the database. - """ - def _set_e2e_device_keys_txn(txn): - old_key_json = self._simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - return False - - self._simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": new_key_json, - } - ) - - return True - - return self.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn - ) - +class EndToEndKeyWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_device_keys( self, query_list, include_all_devices=False, @@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + def _set_e2e_device_keys_txn(txn): + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": new_key_json, + } + ) + + return True + + return self.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) + def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" def _claim_e2e_one_time_keys(txn): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 38809ed0fc..a8d90456e3 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -442,6 +442,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, event_results.reverse() return event_results + @defer.inlineCallbacks + def get_successor_events(self, event_ids): + """Fetch all events that have the given events as a prev event + + Args: + event_ids (iterable[str]) + + Returns: + Deferred[list[str]] + """ + rows = yield self._simple_select_many_batch( + table="event_edges", + column="prev_event_id", + iterable=event_ids, + retcols=("event_id",), + desc="get_successor_events" + ) + + defer.returnValue([ + row["event_id"] for row in rows + ]) + class EventFederationStore(EventFederationWorkerStore): """ Responsible for storing and serving up the various graphs associated diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 0ac665e967..0fd1ccc40a 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -346,15 +346,23 @@ class ReceiptsStore(ReceiptsWorkerStore): def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): + """Inserts a read-receipt into the database if it's newer than the current RR + + Returns: int|None + None if the RR is older than the current RR + otherwise, the rx timestamp of the event that the RR corresponds to + (or 0 if the event is unknown) + """ res = self._simple_select_one_txn( txn, table="events", - retcols=["topological_ordering", "stream_ordering"], + retcols=["stream_ordering", "received_ts"], keyvalues={"event_id": event_id}, allow_none=True ) stream_ordering = int(res["stream_ordering"]) if res else None + rx_ts = res["received_ts"] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts @@ -373,7 +381,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "one for later event %s", event_id, eid, ) - return False + return None txn.call_after( self.get_receipts_for_room.invalidate, (room_id, receipt_type) @@ -429,7 +437,7 @@ class ReceiptsStore(ReceiptsWorkerStore): stream_ordering=stream_ordering, ) - return True + return rx_ts @defer.inlineCallbacks def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): @@ -466,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore): stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - have_persisted = yield self.runInteraction( + event_ts = yield self.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, receipt_type, user_id, linearized_event_id, @@ -474,8 +482,14 @@ class ReceiptsStore(ReceiptsWorkerStore): stream_id=stream_id, ) - if not have_persisted: - defer.returnValue(None) + if event_ts is None: + defer.returnValue(None) + + now = self._clock.time_msec() + logger.debug( + "RR for event %s in %s (%i ms old)", + linearized_event_id, room_id, now - event_ts, + ) yield self.insert_graph_receipt( room_id, receipt_type, user_id, event_ids, data diff --git a/synapse/visibility.py b/synapse/visibility.py index 0281a7c919..efec21673b 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -216,28 +216,36 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, @defer.inlineCallbacks -def filter_events_for_server(store, server_name, events): - # Whatever else we do, we need to check for senders which have requested - # erasure of their data. - erased_senders = yield store.are_users_erased( - (e.sender for e in events), - ) +def filter_events_for_server(store, server_name, events, redact=True, + check_history_visibility_only=False): + """Filter a list of events based on whether given server is allowed to + see them. - def redact_disallowed(event, state): - # if the sender has been gdpr17ed, always return a redacted - # copy of the event. - if erased_senders[event.sender]: + Args: + store (DataStore) + server_name (str) + events (iterable[FrozenEvent]) + redact (bool): Whether to return a redacted version of the event, or + to filter them out entirely. + check_history_visibility_only (bool): Whether to only check the + history visibility, rather than things like if the sender has been + erased. This is used e.g. during pagination to decide whether to + backfill or not. + + Returns + Deferred[list[FrozenEvent]] + """ + + def is_sender_erased(event, erased_senders): + if erased_senders and erased_senders[event.sender]: logger.info( "Sender of %s has been erased, redacting", event.event_id, ) - return prune_event(event) - - # state will be None if we decided we didn't need to filter by - # room membership. - if not state: - return event + return True + return False + def check_event_is_visible(event, state): history = state.get((EventTypes.RoomHistoryVisibility, ''), None) if history: visibility = history.content.get("history_visibility", "shared") @@ -259,17 +267,17 @@ def filter_events_for_server(store, server_name, events): memtype = ev.membership if memtype == Membership.JOIN: - return event + return True elif memtype == Membership.INVITE: if visibility == "invited": - return event + return True else: # server has no users in the room: redact - return prune_event(event) + return False - return event + return True - # Next lets check to see if all the events have a history visibility + # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). event_to_state_ids = yield store.get_state_ids_for_events( @@ -296,16 +304,31 @@ def filter_events_for_server(store, server_name, events): for e in itervalues(event_map) ) + if not check_history_visibility_only: + erased_senders = yield store.are_users_erased( + (e.sender for e in events), + ) + else: + # We don't want to check whether users are erased, which is equivalent + # to no users having been erased. + erased_senders = {} + if all_open: # all the history_visibility state affecting these events is open, so # we don't need to filter by membership state. We *do* need to check # for user erasure, though. if erased_senders: - events = [ - redact_disallowed(e, None) - for e in events - ] + to_return = [] + for e in events: + if not is_sender_erased(e, erased_senders): + to_return.append(e) + elif redact: + to_return.append(prune_event(e)) + + defer.returnValue(to_return) + # If there are no erased users then we can just return the given list + # of events without having to copy it. defer.returnValue(events) # Ok, so we're dealing with events that have non-trivial visibility @@ -361,7 +384,13 @@ def filter_events_for_server(store, server_name, events): for e_id, key_to_eid in iteritems(event_to_state_ids) } - defer.returnValue([ - redact_disallowed(e, event_to_state[e.event_id]) - for e in events - ]) + to_return = [] + for e in events: + erased = is_sender_erased(e, erased_senders) + visible = check_event_is_visible(e, event_to_state[e.event_id]) + if visible and not erased: + to_return.append(e) + elif redact: + to_return.append(prune_event(e)) + + defer.returnValue(to_return) |