diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 76aed8c60a..f8a417cb60 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -164,23 +164,23 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
def start():
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index e4a68715aa..656e0edc0f 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -185,23 +185,23 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
def start():
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 228a297fb8..3de2715132 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -151,23 +151,23 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
def start():
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index e9a99d76e1..d944e0517f 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -183,24 +183,24 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
- ps = FederationSenderServer(
+ ss = FederationSenderServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
- ps.setup()
- ps.start_listening(config.worker_listeners)
+ ss.setup()
def start():
- ps.get_datastore().start_profiling()
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
+ ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-federation-sender", config)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index f5c61dec5b..d9ef6edc3c 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -241,23 +241,23 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = FrontendProxyServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
def start():
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index acc0487adc..4ecf64031b 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -151,23 +151,23 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
def start():
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 0a5f62b509..176d55a783 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -211,24 +211,24 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
- ps = UserDirectoryServer(
+ ss = UserDirectoryServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
- ps.setup()
- ps.start_listening(config.worker_listeners)
+ ss.setup()
def start():
- ps.get_datastore().start_profiling()
+ ss.config.read_certificate_from_disk()
+ ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
+ ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ config
+ )
+ ss.start_listening(config.worker_listeners)
+ ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 19b4ee35d2..13ba9291b0 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -128,7 +128,7 @@ class RoomCreationHandler(BaseHandler):
old_room_version, tombstone_event, tombstone_context,
)
- yield self.clone_exiting_room(
+ yield self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@@ -233,7 +233,7 @@ class RoomCreationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def clone_exiting_room(
+ def clone_existing_room(
self, requester, old_room_id, new_room_id, new_room_version,
tombstone_event_id,
):
@@ -265,6 +265,7 @@ class RoomCreationHandler(BaseHandler):
initial_state = dict()
+ # Replicate relevant room events
types_to_copy = (
(EventTypes.JoinRules, ""),
(EventTypes.Name, ""),
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 07fd3e82fc..9ed5a05cca 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,7 +63,7 @@ class RoomMemberHandler(object):
self.directory_handler = hs.get_handlers().directory_handler
self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler()
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
@@ -161,6 +161,8 @@ class RoomMemberHandler(object):
ratelimit=True,
content=None,
):
+ user_id = target.to_string()
+
if content is None:
content = {}
@@ -168,14 +170,14 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield self.event_creation_hander.create_event(
+ event, context = yield self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
- "state_key": target.to_string(),
+ "state_key": user_id,
# For backwards compatibility:
"membership": membership,
@@ -186,14 +188,14 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield self.event_creation_hander.deduplicate_state_event(
+ duplicate = yield self.event_creation_handler.deduplicate_state_event(
event, context,
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
- yield self.event_creation_hander.handle_new_client_event(
+ yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@@ -204,12 +206,12 @@ class RoomMemberHandler(object):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, target.to_string()),
+ (EventTypes.Member, user_id),
None
)
if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally joined the
+ # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
@@ -218,6 +220,18 @@ class RoomMemberHandler(object):
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
+
+ # Copy over direct message status and room tags if this is a join
+ # on an upgraded room
+
+ # Check if this is an upgraded room
+ predecessor = yield self.store.get_room_predecessor(room_id)
+
+ if predecessor:
+ # It is an upgraded room. Copy over old tags
+ self.copy_room_tags_and_direct_to_room(
+ predecessor["room_id"], room_id, user_id,
+ )
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@@ -227,6 +241,55 @@ class RoomMemberHandler(object):
defer.returnValue(event)
@defer.inlineCallbacks
+ def copy_room_tags_and_direct_to_room(
+ self,
+ old_room_id,
+ new_room_id,
+ user_id,
+ ):
+ """Copies the tags and direct room state from one room to another.
+
+ Args:
+ old_room_id (str)
+ new_room_id (str)
+ user_id (str)
+
+ Returns:
+ Deferred[None]
+ """
+ # Retrieve user account data for predecessor room
+ user_account_data, _ = yield self.store.get_account_data_for_user(
+ user_id,
+ )
+
+ # Copy direct message state if applicable
+ direct_rooms = user_account_data.get("m.direct", {})
+
+ # Check which key this room is under
+ if isinstance(direct_rooms, dict):
+ for key, room_id_list in direct_rooms.items():
+ if old_room_id in room_id_list and new_room_id not in room_id_list:
+ # Add new room_id to this key
+ direct_rooms[key].append(new_room_id)
+
+ # Save back to user's m.direct account data
+ yield self.store.add_account_data_for_user(
+ user_id, "m.direct", direct_rooms,
+ )
+ break
+
+ # Copy room tags if applicable
+ room_tags = yield self.store.get_tags_for_room(
+ user_id, old_room_id,
+ )
+
+ # Copy each room tag to the new room
+ for tag, tag_content in room_tags.items():
+ yield self.store.add_tag_to_room(
+ user_id, new_room_id, tag, tag_content
+ )
+
+ @defer.inlineCallbacks
def update_membership(
self,
requester,
@@ -493,7 +556,7 @@ class RoomMemberHandler(object):
else:
requester = synapse.types.create_requester(target_user)
- prev_event = yield self.event_creation_hander.deduplicate_state_event(
+ prev_event = yield self.event_creation_handler.deduplicate_state_event(
event, context,
)
if prev_event is not None:
@@ -513,7 +576,7 @@ class RoomMemberHandler(object):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield self.event_creation_hander.handle_new_client_event(
+ yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@@ -527,7 +590,7 @@ class RoomMemberHandler(object):
)
if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally joined the
+ # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
@@ -755,7 +818,7 @@ class RoomMemberHandler(object):
)
)
- yield self.event_creation_hander.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 1788e9a34a..4a6f634c8b 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+import attr
+from netaddr import IPAddress
from zope.interface import implementer
from twisted.internet import defer
@@ -22,7 +24,6 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
-from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable
@@ -86,29 +87,20 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
+ parsed_uri = URI.fromBytes(uri, defaultPort=-1)
+ res = yield self._route_matrix_uri(parsed_uri)
- parsed_uri = URI.fromBytes(uri)
- server_name_bytes = parsed_uri.netloc
- host, port = parse_server_name(server_name_bytes.decode("ascii"))
-
+ # set up the TLS connection params
+ #
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if self._tls_client_options_factory is None:
tls_options = None
else:
- tls_options = self._tls_client_options_factory.get_options(host)
-
- if port is not None:
- target = (host, port)
- else:
- service_name = b"_matrix._tcp.%s" % (server_name_bytes, )
- server_list = yield self._srv_resolver.resolve_service(service_name)
- if not server_list:
- target = (host, 8448)
- logger.debug("No SRV record for %s, using %s", host, target)
- else:
- target = pick_server_from_list(server_list)
+ tls_options = self._tls_client_options_factory.get_options(
+ res.tls_server_name.decode("ascii")
+ )
# make sure that the Host header is set correctly
if headers is None:
@@ -117,13 +109,13 @@ class MatrixFederationAgent(object):
headers = headers.copy()
if not headers.hasHeader(b'host'):
- headers.addRawHeader(b'host', server_name_bytes)
+ headers.addRawHeader(b'host', res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
- logger.info("Connecting to %s:%s", target[0], target[1])
- ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
+ logger.info("Connecting to %s:%s", res.target_host, res.target_port)
+ ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
@@ -133,3 +125,111 @@ class MatrixFederationAgent(object):
agent.request(method, uri, headers, bodyProducer)
)
defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _route_matrix_uri(self, parsed_uri):
+ """Helper for `request`: determine the routing for a Matrix URI
+
+ Args:
+ parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
+ parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
+ if there is no explicit port given.
+
+ Returns:
+ Deferred[_RoutingResult]
+ """
+ # check for an IP literal
+ try:
+ ip_address = IPAddress(parsed_uri.host.decode("ascii"))
+ except Exception:
+ # not an IP address
+ ip_address = None
+
+ if ip_address:
+ port = parsed_uri.port
+ if port == -1:
+ port = 8448
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ ))
+
+ if parsed_uri.port != -1:
+ # there is an explicit port
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ ))
+
+ # try a SRV lookup
+ service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+ server_list = yield self._srv_resolver.resolve_service(service_name)
+
+ if not server_list:
+ target_host = parsed_uri.host
+ port = 8448
+ logger.debug(
+ "No SRV record for %s, using %s:%i",
+ parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ )
+ else:
+ target_host, port = pick_server_from_list(server_list)
+ logger.debug(
+ "Picked %s:%i from SRV records for %s",
+ target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ )
+
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ ))
+
+
+@attr.s
+class _RoutingResult(object):
+ """The result returned by `_route_matrix_uri`.
+
+ Contains the parameters needed to direct a federation connection to a particular
+ server.
+
+ Where a SRV record points to several servers, this object contains a single server
+ chosen from the list.
+ """
+
+ host_header = attr.ib()
+ """
+ The value we should assign to the Host header (host:port from the matrix
+ URI, or .well-known).
+
+ :type: bytes
+ """
+
+ tls_server_name = attr.ib()
+ """
+ The server name we should set in the SNI (typically host, without port, from the
+ matrix URI or .well-known)
+
+ :type: bytes
+ """
+
+ target_host = attr.ib()
+ """
+ The hostname (or IP literal) we should route the TCP connection to (the target of the
+ SRV record, or the hostname from the URL/.well-known)
+
+ :type: bytes
+ """
+
+ target_port = attr.ib()
+ """
+ The port we should route the TCP connection to (the target of the SRV record, or
+ the port from the URL/.well-known, or 8448)
+
+ :type: int
+ """
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index fcfe7857f6..48da4d557f 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -89,7 +89,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -172,7 +172,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- event = yield self.event_creation_hander.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
@@ -189,7 +189,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -211,7 +211,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_hander.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f62f70b9f1..5109bc3e2e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode
@@ -196,6 +196,12 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = {"user_ips"}
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.database_engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
if self.database_engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update.
@@ -230,7 +236,7 @@ class SQLBaseStore(object):
self._unsafe_to_upsert_tables.discard("user_ips")
# If there's any tables left to check, reschedule to run.
- if self._unsafe_to_upsert_tables:
+ if self.updates:
self._clock.call_later(
15.0,
run_as_background_process,
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 5fe1ca2de7..60cdc884e6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -240,7 +240,7 @@ class BackgroundUpdateStore(SQLBaseStore):
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
- The hander is responsible for updating the progress of the update.
+ The handler is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 31b8449ca1..059ab81055 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -33,14 +33,10 @@ class Sqlite3Engine(object):
@property
def can_native_upsert(self):
"""
- Do we support native UPSERTs?
+ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
+ more work we haven't done yet to tell what was inserted vs updated.
"""
- # SQLite3 3.24+ supports them, but empirically the unit tests don't work
- # when its enabled.
- # FIXME: Figure out what is wrong so we can re-enable native upserts
-
- # return self.module.sqlite_version_info >= (3, 24, 0)
- return False
+ return self.module.sqlite_version_info >= (3, 24, 0)
def check_database(self, txn):
pass
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d6fc8edd4c..9e7e09b8c1 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -197,15 +197,21 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
- is_insert = yield self.runInteraction(
+ yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id
)
- if is_insert:
- self.user_last_seen_monthly_active.invalidate((user_id,))
+ user_in_mau = self.user_last_seen_monthly_active.cache.get(
+ (user_id,),
+ None,
+ update_metrics=False
+ )
+ if user_in_mau is None:
self.get_monthly_active_count.invalidate(())
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
|