diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/app/client_reader.py | 12 | ||||
-rw-r--r-- | synapse/app/event_creator.py | 12 | ||||
-rw-r--r-- | synapse/app/federation_reader.py | 12 | ||||
-rw-r--r-- | synapse/app/federation_sender.py | 18 | ||||
-rw-r--r-- | synapse/app/frontend_proxy.py | 12 | ||||
-rw-r--r-- | synapse/app/media_repository.py | 12 | ||||
-rw-r--r-- | synapse/app/user_dir.py | 18 | ||||
-rw-r--r-- | synapse/handlers/room.py | 5 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 85 | ||||
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 140 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 8 | ||||
-rw-r--r-- | synapse/storage/_base.py | 10 | ||||
-rw-r--r-- | synapse/storage/background_updates.py | 2 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite.py | 10 | ||||
-rw-r--r-- | synapse/storage/monthly_active_users.py | 12 |
15 files changed, 270 insertions, 98 deletions
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 76aed8c60a..f8a417cb60 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -164,23 +164,23 @@ def start(config_options): database_engine = create_engine(config.database_config) - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - ss = ClientReaderServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) ss.setup() - ss.start_listening(config.worker_listeners) def start(): + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index e4a68715aa..656e0edc0f 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -185,23 +185,23 @@ def start(config_options): database_engine = create_engine(config.database_config) - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - ss = EventCreatorServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) ss.setup() - ss.start_listening(config.worker_listeners) def start(): + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 228a297fb8..3de2715132 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -151,23 +151,23 @@ def start(config_options): database_engine = create_engine(config.database_config) - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - ss = FederationReaderServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) ss.setup() - ss.start_listening(config.worker_listeners) def start(): + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index e9a99d76e1..d944e0517f 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -183,24 +183,24 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.send_federation = True - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - - ps = FederationSenderServer( + ss = FederationSenderServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) - ps.setup() - ps.start_listening(config.worker_listeners) + ss.setup() def start(): - ps.get_datastore().start_profiling() + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) + ss.get_datastore().start_profiling() reactor.callWhenRunning(start) _base.start_worker_reactor("synapse-federation-sender", config) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index f5c61dec5b..d9ef6edc3c 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -241,23 +241,23 @@ def start(config_options): database_engine = create_engine(config.database_config) - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - ss = FrontendProxyServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) ss.setup() - ss.start_listening(config.worker_listeners) def start(): + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index acc0487adc..4ecf64031b 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -151,23 +151,23 @@ def start(config_options): database_engine = create_engine(config.database_config) - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - ss = MediaRepositoryServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) ss.setup() - ss.start_listening(config.worker_listeners) def start(): + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 0a5f62b509..176d55a783 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -211,24 +211,24 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.update_user_directory = True - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - - ps = UserDirectoryServer( + ss = UserDirectoryServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) - ps.setup() - ps.start_listening(config.worker_listeners) + ss.setup() def start(): - ps.get_datastore().start_profiling() + ss.config.read_certificate_from_disk() + ss.tls_server_context_factory = context_factory.ServerContextFactory(config) + ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + ss.start_listening(config.worker_listeners) + ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 19b4ee35d2..13ba9291b0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -128,7 +128,7 @@ class RoomCreationHandler(BaseHandler): old_room_version, tombstone_event, tombstone_context, ) - yield self.clone_exiting_room( + yield self.clone_existing_room( requester, old_room_id=old_room_id, new_room_id=new_room_id, @@ -233,7 +233,7 @@ class RoomCreationHandler(BaseHandler): ) @defer.inlineCallbacks - def clone_exiting_room( + def clone_existing_room( self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id, ): @@ -265,6 +265,7 @@ class RoomCreationHandler(BaseHandler): initial_state = dict() + # Replicate relevant room events types_to_copy = ( (EventTypes.JoinRules, ""), (EventTypes.Name, ""), diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 07fd3e82fc..9ed5a05cca 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -63,7 +63,7 @@ class RoomMemberHandler(object): self.directory_handler = hs.get_handlers().directory_handler self.registration_handler = hs.get_handlers().registration_handler self.profile_handler = hs.get_profile_handler() - self.event_creation_hander = hs.get_event_creation_handler() + self.event_creation_handler = hs.get_event_creation_handler() self.member_linearizer = Linearizer(name="member") @@ -161,6 +161,8 @@ class RoomMemberHandler(object): ratelimit=True, content=None, ): + user_id = target.to_string() + if content is None: content = {} @@ -168,14 +170,14 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - event, context = yield self.event_creation_hander.create_event( + event, context = yield self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, "content": content, "room_id": room_id, "sender": requester.user.to_string(), - "state_key": target.to_string(), + "state_key": user_id, # For backwards compatibility: "membership": membership, @@ -186,14 +188,14 @@ class RoomMemberHandler(object): ) # Check if this event matches the previous membership event for the user. - duplicate = yield self.event_creation_hander.deduplicate_state_event( + duplicate = yield self.event_creation_handler.deduplicate_state_event( event, context, ) if duplicate is not None: # Discard the new event since this membership change is a no-op. defer.returnValue(duplicate) - yield self.event_creation_hander.handle_new_client_event( + yield self.event_creation_handler.handle_new_client_event( requester, event, context, @@ -204,12 +206,12 @@ class RoomMemberHandler(object): prev_state_ids = yield context.get_prev_state_ids(self.store) prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, target.to_string()), + (EventTypes.Member, user_id), None ) if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the + # Only fire user_joined_room if the user has actually joined the # room. Don't bother if the user is just changing their profile # info. newly_joined = True @@ -218,6 +220,18 @@ class RoomMemberHandler(object): newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: yield self._user_joined_room(target, room_id) + + # Copy over direct message status and room tags if this is a join + # on an upgraded room + + # Check if this is an upgraded room + predecessor = yield self.store.get_room_predecessor(room_id) + + if predecessor: + # It is an upgraded room. Copy over old tags + self.copy_room_tags_and_direct_to_room( + predecessor["room_id"], room_id, user_id, + ) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) @@ -227,6 +241,55 @@ class RoomMemberHandler(object): defer.returnValue(event) @defer.inlineCallbacks + def copy_room_tags_and_direct_to_room( + self, + old_room_id, + new_room_id, + user_id, + ): + """Copies the tags and direct room state from one room to another. + + Args: + old_room_id (str) + new_room_id (str) + user_id (str) + + Returns: + Deferred[None] + """ + # Retrieve user account data for predecessor room + user_account_data, _ = yield self.store.get_account_data_for_user( + user_id, + ) + + # Copy direct message state if applicable + direct_rooms = user_account_data.get("m.direct", {}) + + # Check which key this room is under + if isinstance(direct_rooms, dict): + for key, room_id_list in direct_rooms.items(): + if old_room_id in room_id_list and new_room_id not in room_id_list: + # Add new room_id to this key + direct_rooms[key].append(new_room_id) + + # Save back to user's m.direct account data + yield self.store.add_account_data_for_user( + user_id, "m.direct", direct_rooms, + ) + break + + # Copy room tags if applicable + room_tags = yield self.store.get_tags_for_room( + user_id, old_room_id, + ) + + # Copy each room tag to the new room + for tag, tag_content in room_tags.items(): + yield self.store.add_tag_to_room( + user_id, new_room_id, tag, tag_content + ) + + @defer.inlineCallbacks def update_membership( self, requester, @@ -493,7 +556,7 @@ class RoomMemberHandler(object): else: requester = synapse.types.create_requester(target_user) - prev_event = yield self.event_creation_hander.deduplicate_state_event( + prev_event = yield self.event_creation_handler.deduplicate_state_event( event, context, ) if prev_event is not None: @@ -513,7 +576,7 @@ class RoomMemberHandler(object): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield self.event_creation_hander.handle_new_client_event( + yield self.event_creation_handler.handle_new_client_event( requester, event, context, @@ -527,7 +590,7 @@ class RoomMemberHandler(object): ) if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the + # Only fire user_joined_room if the user has actually joined the # room. Don't bother if the user is just changing their profile # info. newly_joined = True @@ -755,7 +818,7 @@ class RoomMemberHandler(object): ) ) - yield self.event_creation_hander.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 1788e9a34a..4a6f634c8b 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,6 +14,8 @@ # limitations under the License. import logging +import attr +from netaddr import IPAddress from zope.interface import implementer from twisted.internet import defer @@ -22,7 +24,6 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent -from synapse.http.endpoint import parse_server_name from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list from synapse.util.logcontext import make_deferred_yieldable @@ -86,29 +87,20 @@ class MatrixFederationAgent(object): response from being received (including problems that prevent the request from being sent). """ + parsed_uri = URI.fromBytes(uri, defaultPort=-1) + res = yield self._route_matrix_uri(parsed_uri) - parsed_uri = URI.fromBytes(uri) - server_name_bytes = parsed_uri.netloc - host, port = parse_server_name(server_name_bytes.decode("ascii")) - + # set up the TLS connection params + # # XXX disabling TLS is really only supported here for the benefit of the # unit tests. We should make the UTs cope with TLS rather than having to make # the code support the unit tests. if self._tls_client_options_factory is None: tls_options = None else: - tls_options = self._tls_client_options_factory.get_options(host) - - if port is not None: - target = (host, port) - else: - service_name = b"_matrix._tcp.%s" % (server_name_bytes, ) - server_list = yield self._srv_resolver.resolve_service(service_name) - if not server_list: - target = (host, 8448) - logger.debug("No SRV record for %s, using %s", host, target) - else: - target = pick_server_from_list(server_list) + tls_options = self._tls_client_options_factory.get_options( + res.tls_server_name.decode("ascii") + ) # make sure that the Host header is set correctly if headers is None: @@ -117,13 +109,13 @@ class MatrixFederationAgent(object): headers = headers.copy() if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', server_name_bytes) + headers.addRawHeader(b'host', res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): - logger.info("Connecting to %s:%s", target[0], target[1]) - ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) + logger.info("Connecting to %s:%s", res.target_host, res.target_port) + ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) return ep @@ -133,3 +125,111 @@ class MatrixFederationAgent(object): agent.request(method, uri, headers, bodyProducer) ) defer.returnValue(res) + + @defer.inlineCallbacks + def _route_matrix_uri(self, parsed_uri): + """Helper for `request`: determine the routing for a Matrix URI + + Args: + parsed_uri (twisted.web.client.URI): uri to route. Note that it should be + parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 + if there is no explicit port given. + + Returns: + Deferred[_RoutingResult] + """ + # check for an IP literal + try: + ip_address = IPAddress(parsed_uri.host.decode("ascii")) + except Exception: + # not an IP address + ip_address = None + + if ip_address: + port = parsed_uri.port + if port == -1: + port = 8448 + defer.returnValue(_RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, + )) + + if parsed_uri.port != -1: + # there is an explicit port + defer.returnValue(_RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + )) + + # try a SRV lookup + service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) + server_list = yield self._srv_resolver.resolve_service(service_name) + + if not server_list: + target_host = parsed_uri.host + port = 8448 + logger.debug( + "No SRV record for %s, using %s:%i", + parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + ) + else: + target_host, port = pick_server_from_list(server_list) + logger.debug( + "Picked %s:%i from SRV records for %s", + target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + ) + + defer.returnValue(_RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + )) + + +@attr.s +class _RoutingResult(object): + """The result returned by `_route_matrix_uri`. + + Contains the parameters needed to direct a federation connection to a particular + server. + + Where a SRV record points to several servers, this object contains a single server + chosen from the list. + """ + + host_header = attr.ib() + """ + The value we should assign to the Host header (host:port from the matrix + URI, or .well-known). + + :type: bytes + """ + + tls_server_name = attr.ib() + """ + The server name we should set in the SNI (typically host, without port, from the + matrix URI or .well-known) + + :type: bytes + """ + + target_host = attr.ib() + """ + The hostname (or IP literal) we should route the TCP connection to (the target of the + SRV record, or the hostname from the URL/.well-known) + + :type: bytes + """ + + target_port = attr.ib() + """ + The port we should route the TCP connection to (the target of the SRV record, or + the port from the URL/.well-known, or 8448) + + :type: int + """ diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fcfe7857f6..48da4d557f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -89,7 +89,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomStateEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() - self.event_creation_hander = hs.get_event_creation_handler() + self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() @@ -172,7 +172,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): content=content, ) else: - event = yield self.event_creation_hander.create_and_send_nonmember_event( + event = yield self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id, @@ -189,7 +189,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) - self.event_creation_hander = hs.get_event_creation_handler() + self.event_creation_handler = hs.get_event_creation_handler() def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] @@ -211,7 +211,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): if b'ts' in request.args and requester.app_service: event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) - event = yield self.event_creation_hander.create_and_send_nonmember_event( + event = yield self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f62f70b9f1..5109bc3e2e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -27,7 +27,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.util.caches.descriptors import Cache from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.stringutils import exception_to_unicode @@ -196,6 +196,12 @@ class SQLBaseStore(object): # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = {"user_ips"} + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.database_engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + if self.database_engine.can_native_upsert: # Check ASAP (and then later, every 1s) to see if we have finished # background updates of tables that aren't safe to update. @@ -230,7 +236,7 @@ class SQLBaseStore(object): self._unsafe_to_upsert_tables.discard("user_ips") # If there's any tables left to check, reschedule to run. - if self._unsafe_to_upsert_tables: + if self.updates: self._clock.call_later( 15.0, run_as_background_process, diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 5fe1ca2de7..60cdc884e6 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -240,7 +240,7 @@ class BackgroundUpdateStore(SQLBaseStore): * An integer count of the number of items to update in this batch. The handler should return a deferred integer count of items updated. - The hander is responsible for updating the progress of the update. + The handler is responsible for updating the progress of the update. Args: update_name(str): The name of the update that this code handles. diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 31b8449ca1..059ab81055 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -33,14 +33,10 @@ class Sqlite3Engine(object): @property def can_native_upsert(self): """ - Do we support native UPSERTs? + Do we support native UPSERTs? This requires SQLite3 3.24+, plus some + more work we haven't done yet to tell what was inserted vs updated. """ - # SQLite3 3.24+ supports them, but empirically the unit tests don't work - # when its enabled. - # FIXME: Figure out what is wrong so we can re-enable native upserts - - # return self.module.sqlite_version_info >= (3, 24, 0) - return False + return self.module.sqlite_version_info >= (3, 24, 0) def check_database(self, txn): pass diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d6fc8edd4c..9e7e09b8c1 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -197,15 +197,21 @@ class MonthlyActiveUsersStore(SQLBaseStore): if is_support: return - is_insert = yield self.runInteraction( + yield self.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - if is_insert: - self.user_last_seen_monthly_active.invalidate((user_id,)) + user_in_mau = self.user_last_seen_monthly_active.cache.get( + (user_id,), + None, + update_metrics=False + ) + if user_in_mau is None: self.get_monthly_active_count.invalidate(()) + self.user_last_seen_monthly_active.invalidate((user_id,)) + def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member |