From 6e18805ac2906ea52d2374024392332af1a603b7 Mon Sep 17 00:00:00 2001
From: Glyph
Date: Sun, 11 Dec 2016 01:44:02 -0800
Subject: IPv6 support for client.py

This is an (untested) general sketch of how to use wrapClientTLS to
implement TLS over IPv6, as well as faster connections over IPv4.
---
 synapse/http/client.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3ec9bc7faf..c60e3c2ac0 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -386,26 +386,21 @@ class SpiderEndpointFactory(object):
 
     def endpointForURI(self, uri):
         logger.info("Getting endpoint for %s", uri.toBytes())
+
         if uri.scheme == "http":
-            return SpiderEndpoint(
-                reactor, uri.host, uri.port, self.blacklist, self.whitelist,
-                endpoint=TCP4ClientEndpoint,
-                endpoint_kw_args={
-                    'timeout': 15
-                },
-            )
+            endpoint_factory = HostnameEndpoint
         elif uri.scheme == "https":
-            tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
-            return SpiderEndpoint(
-                reactor, uri.host, uri.port, self.blacklist, self.whitelist,
-                endpoint=SSL4ClientEndpoint,
-                endpoint_kw_args={
-                    'sslContextFactory': tlsPolicy,
-                    'timeout': 15
-                },
-            )
+            tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
+            def endpoint_factory(reactor, host, port, **kw):
+                return wrapClientTLS(tlsCreator, HostnameEndpoint(reactor, host, port, **kw)
         else:
             logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
+            return None
+        return SpiderEndpoint(
+            reactor, uri.host, uri.port, self.blacklist, self.whitelist,
+            endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
+        )
+
 
 class SpiderHttpClient(SimpleHttpClient):
--
cgit 1.4.1
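[Editorial note on the technique the patch above adopts, as a hedged standalone sketch: HostnameEndpoint resolves both A and AAAA records and races the connection attempts, which is what gives IPv6 support "for free" compared to TCP4ClientEndpoint, and wrapClientTLS layers TLS over any stream endpoint. The hostname and trivial factory below are placeholders, not part of the patch; assumes Twisted >= 16.0.]

from twisted.internet import reactor
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.ssl import optionsForClientTLS

# Plain TCP endpoint: resolves IPv4 and IPv6 addresses and tries both.
tcp = HostnameEndpoint(reactor, b"example.com", 443, timeout=15)

# TLS is layered on top, replacing the old
# SSL4ClientEndpoint(..., sslContextFactory=...) pattern.
tls = wrapClientTLS(optionsForClientTLS(u"example.com"), tcp)

d = tls.connect(Factory.forProtocol(Protocol))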
From 9f07f4c5595b0eff19c9740c44803700b01b14af Mon Sep 17 00:00:00 2001
From: Glyph
Date: Sun, 11 Dec 2016 01:46:43 -0800
Subject: IPv6 support for endpoint.py

Similar to https://github.com/matrix-org/synapse/pull/1689, but for
endpoint.py
---
 synapse/http/endpoint.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 442696d393..5e2e428dbf 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer
 from twisted.internet.error import ConnectError
 from twisted.names import client, dns
@@ -58,11 +58,11 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         endpoint_kw_args.update(timeout=timeout)
 
     if ssl_context_factory is None:
-        transport_endpoint = TCP4ClientEndpoint
+        transport_endpoint = HostnameEndpoint
         default_port = 8008
     else:
-        transport_endpoint = SSL4ClientEndpoint
-        endpoint_kw_args.update(sslContextFactory=ssl_context_factory)
+        def transport_endpoint(reactor, host, port):
+            return wrapClientTLS(ssl_context_factory, HostnameEndpoint(reactor, host, port))
         default_port = 8448
 
     if port is None:
--
cgit 1.4.1

From d3bd94805f6ef68e75d8c2e39b8c97ea5ce88286 Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Mon, 12 Dec 2016 16:19:54 +0100
Subject: Fixup for #1689 and #1690
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Johannes Löthberg
---
 synapse/http/client.py   | 11 +++++++----
 synapse/http/endpoint.py | 10 ++++++----
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/synapse/http/client.py b/synapse/http/client.py
index c60e3c2ac0..37988716e7 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -25,7 +25,7 @@ from synapse.http.endpoint import SpiderEndpoint
 from canonicaljson import encode_canonical_json
 
 from twisted.internet import defer, reactor, ssl, protocol, task
-from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import (
     BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
     readBody, PartialDownloadError,
@@ -386,13 +386,16 @@ class SpiderEndpointFactory(object):
 
     def endpointForURI(self, uri):
         logger.info("Getting endpoint for %s", uri.toBytes())
-
+
         if uri.scheme == "http":
             endpoint_factory = HostnameEndpoint
         elif uri.scheme == "https":
             tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
+
             def endpoint_factory(reactor, host, port, **kw):
-                return wrapClientTLS(tlsCreator, HostnameEndpoint(reactor, host, port, **kw)
+                return wrapClientTLS(
+                    tlsCreator,
+                    HostnameEndpoint(reactor, host, port, **kw))
         else:
             logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
             return None
@@ -400,7 +403,7 @@ class SpiderEndpointFactory(object):
             reactor, uri.host, uri.port, self.blacklist, self.whitelist,
             endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
         )
-
+
 
 class SpiderHttpClient(SimpleHttpClient):

diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 5e2e428dbf..1c17a28406 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -61,8 +61,10 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         transport_endpoint = HostnameEndpoint
         default_port = 8008
     else:
-        def transport_endpoint(reactor, host, port):
-            return wrapClientTLS(ssl_context_factory, HostnameEndpoint(reactor, host, port))
+        def transport_endpoint(reactor, host, port, timeout):
+            return wrapClientTLS(
+                ssl_context_factory,
+                HostnameEndpoint(reactor, host, port, timeout=timeout))
         default_port = 8448
 
     if port is None:
@@ -80,7 +82,7 @@ class SpiderEndpoint(object):
     Implements twisted.internet.interfaces.IStreamClientEndpoint.
     """
     def __init__(self, reactor, host, port, blacklist, whitelist,
-                 endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
+                 endpoint=HostnameEndpoint, endpoint_kw_args={}):
        self.reactor = reactor
        self.host = host
        self.port = port
@@ -118,7 +120,7 @@ class SRVClientEndpoint(object):
     """
 
     def __init__(self, reactor, service, domain, protocol="tcp",
-                 default_port=None, endpoint=TCP4ClientEndpoint,
+                 default_port=None, endpoint=HostnameEndpoint,
                  endpoint_kw_args={}):
         self.reactor = reactor
         self.service_name = "_%s._%s.%s" % (service, protocol, domain)
--
cgit 1.4.1

From 0648e76979e4626cf3719edc5958eb4f170e0d1e Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Mon, 12 Dec 2016 18:40:39 +0100
Subject: Remove spurious newline
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Apparently I just removed the spaces instead...

Signed-off-by: Johannes Löthberg
---
 synapse/http/client.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37988716e7..ca2f770f5d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -405,7 +405,6 @@ class SpiderEndpointFactory(object):
         )
 
-
 class SpiderHttpClient(SimpleHttpClient):
     """
     Separate HTTP client for spidering arbitrary URLs.
--
cgit 1.4.1

From b2f8642d3df30c88704e02422325b394880f66eb Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 16 Dec 2016 16:11:43 +0000
Subject: Cache network room list queries.

---
 synapse/handlers/room_list.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 667223df0c..19eebbd43f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -62,17 +62,18 @@ class RoomListHandler(BaseHandler):
             appservice and network id to use an appservice specific one.
             Setting to None returns all public rooms across all lists.
         """
-        if search_filter or (network_tuple and network_tuple.appservice_id is not None):
+        if search_filter:
             # We explicitly don't bother caching searches or requests for
             # appservice specific lists.
             return self._get_public_room_list(
                 limit, since_token, search_filter, network_tuple=network_tuple,
             )
 
-        result = self.response_cache.get((limit, since_token))
+        key = (limit, since_token, network_tuple)
+        result = self.response_cache.get(key)
        if not result:
            result = self.response_cache.set(
-                (limit, since_token),
+                key,
                self._get_public_room_list(
                    limit, since_token, network_tuple=network_tuple
                )
--
cgit 1.4.1
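[Editorial note on the caching change above: ResponseCache deduplicates requests by key, so network_tuple has to be part of the key or lists for different appservice networks would be served from each other's cache. A toy stand-in — this class is a simplification, not synapse's deferred-aware ResponseCache — shows the get-or-set idiom:]

class ToyResponseCache(object):
    """Simplified stand-in for synapse's ResponseCache."""
    def __init__(self):
        self._entries = {}

    def get(self, key):
        return self._entries.get(key)

    def set(self, key, value):
        self._entries[key] = value
        return value


cache = ToyResponseCache()

def get_room_list(limit, since_token, network_tuple):
    # Including network_tuple in the key is the point of the patch above.
    key = (limit, since_token, network_tuple)
    result = cache.get(key)
    if not result:
        result = cache.set(key, {"rooms": [], "limit": limit})
    return result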
From 7dfd70fc834a14b7003beb220eebae6fead5dbf3 Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Sun, 18 Dec 2016 20:42:43 +0100
Subject: Add support for specifying multiple bind addresses
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Johannes Löthberg
---
 synapse/app/appservice.py        | 54 +++++++++++++++++-----------
 synapse/app/client_reader.py     | 54 +++++++++++++++++-----------
 synapse/app/federation_reader.py | 54 +++++++++++++++++-----------
 synapse/app/federation_sender.py | 54 +++++++++++++++++-----------
 synapse/app/homeserver.py        | 76 ++++++++++++++++++++++++----------------
 synapse/app/media_repository.py  | 54 +++++++++++++++++-----------
 synapse/app/pusher.py            | 65 +++++++++++++++++++++++-----------
 synapse/app/synchrotron.py       | 54 +++++++++++++++++-----------
 8 files changed, 294 insertions(+), 171 deletions(-)

diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index dd9ee406a1..e24c1e1eda 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -76,7 +76,8 @@ class AppserviceServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -85,16 +86,22 @@ class AppserviceServer(HomeServer):
                 resources[METRICS_PREFIX] = MetricsResource(self)
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+
         logger.info("Synapse appservice now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -102,15 +109,22 @@ class AppserviceServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 0086a2977e..305a82b664 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -90,7 +90,8 @@ class ClientReaderServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -108,16 +109,22 @@ class ClientReaderServer(HomeServer):
             })
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+
         logger.info("Synapse client reader now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -125,15 +132,22 @@ class ClientReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index b5f59a9931..321dfc7cd5 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -86,7 +86,8 @@ class FederationReaderServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -99,16 +100,22 @@ class FederationReaderServer(HomeServer):
             })
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+
         logger.info("Synapse federation reader now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -116,15 +123,22 @@ class FederationReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 80ea4c8062..8092fd316c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -82,7 +82,8 @@ class FederationSenderServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -91,16 +92,22 @@ class FederationSenderServer(HomeServer):
                 resources[METRICS_PREFIX] = MetricsResource(self)
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+
         logger.info("Synapse federation_sender now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -108,15 +115,22 @@ class FederationSenderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 54f35900f8..2d6becad1a 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -107,7 +107,8 @@ def build_resource_for_web_client(hs):
 class SynapseHomeServer(HomeServer):
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         tls = listener_config.get("tls", False)
         site_tag = listener_config.get("tag", port)
 
@@ -173,29 +174,35 @@ class SynapseHomeServer(HomeServer):
             root_resource = Resource()
 
         root_resource = create_resource_tree(resources, root_resource)
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
         if tls:
-            reactor.listenSSL(
-                port,
-                SynapseSite(
-                    "synapse.access.https.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                self.tls_server_context_factory,
-                interface=bind_address
-            )
+            for address in bind_addresses:
+                reactor.listenSSL(
+                    port,
+                    SynapseSite(
+                        "synapse.access.https.%s" % (site_tag,),
+                        site_tag,
+                        listener_config,
+                        root_resource,
+                    ),
+                    self.tls_server_context_factory,
+                    interface=address
+                )
         else:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=bind_address
-            )
+            for address in bind_addresses:
+                reactor.listenTCP(
+                    port,
+                    SynapseSite(
+                        "synapse.access.http.%s" % (site_tag,),
+                        site_tag,
+                        listener_config,
+                        root_resource,
+                    ),
+                    interface=address
+                )
         logger.info("Synapse now listening on port %d", port)
 
     def start_listening(self):
@@ -205,15 +212,22 @@ class SynapseHomeServer(HomeServer):
             if listener["type"] == "http":
                 self._listener_http(config, listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 44c19a1bef..c121107245 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -87,7 +87,8 @@ class MediaRepositoryServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -105,16 +106,22 @@ class MediaRepositoryServer(HomeServer):
             })
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+
         logger.info("Synapse media repository now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -122,15 +129,22 @@ class MediaRepositoryServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index a0e765c54f..159850c44c 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -121,7 +121,8 @@ class PusherServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -130,16 +131,33 @@ class PusherServer(HomeServer):
                 resources[METRICS_PREFIX] = MetricsResource(self)
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+        else:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=bind_address
+            )
+
         logger.info("Synapse pusher now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -147,15 +165,22 @@ class PusherServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])

diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index bf1b995dc2..56143d0025 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -289,7 +289,8 @@ class SynchrotronServer(HomeServer):
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", "")
+        bind_address = listener_config.get("bind_address", None)
+        bind_addresses = listener_config.get("bind_addresses", [])
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -310,16 +311,22 @@ class SynchrotronServer(HomeServer):
             })
 
         root_resource = create_resource_tree(resources, Resource())
-        reactor.listenTCP(
-            port,
-            SynapseSite(
-                "synapse.access.http.%s" % (site_tag,),
-                site_tag,
-                listener_config,
-                root_resource,
-            ),
-            interface=bind_address
-        )
+
+        if bind_address:
+            bind_addresses.append(bind_address)
+
+        for address in bind_addresses:
+            reactor.listenTCP(
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                interface=address
+            )
+
         logger.info("Synapse synchrotron now listening on port %d", port)
 
     def start_listening(self, listeners):
@@ -327,15 +334,22 @@ class SynchrotronServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                reactor.listenTCP(
-                    listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
-                    interface=listener.get("bind_address", '127.0.0.1')
-                )
+                bind_address = listener.get("bind_address", None)
+                bind_addresses = listener.get("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+
+                for address in bind_addresses:
+                    reactor.listenTCP(
+                        listener["port"],
+                        manhole(
+                            username="matrix",
+                            password="rabbithole",
+                            globals={"hs": self},
+                        ),
+                        interface=address
+                    )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
--
cgit 1.4.1
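[Editorial note: a compact illustration of the pattern repeated across all eight workers above — fold the legacy singular option into the plural list, then bind once per address. The listener dict here is made up for the example; a later patch in this series tightens the truthiness test to an explicit None check.]

listener = {"port": 8008, "bind_address": "127.0.0.1", "bind_addresses": ["::1"]}

bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])

if bind_address:
    bind_addresses.append(bind_address)

for address in bind_addresses:
    # In the real code this is reactor.listenTCP(port, site, interface=address)
    print("would listen on %s port %d" % (address, listener["port"]))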
From c95e9fff990722dbeb8bc7971640a517ea7f5fbb Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Sun, 18 Dec 2016 20:54:22 +0100
Subject: Make default homeserver config use bind_addresses
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Johannes Löthberg
---
 synapse/config/server.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/synapse/config/server.py b/synapse/config/server.py
index 634d8e6fe5..1b9e10b527 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -155,9 +155,10 @@ class ServerConfig(Config):
         # The port to listen for HTTPS requests on.
         port: %(bind_port)s
 
-        # Local interface to listen on.
-        # The empty string will cause synapse to listen on all interfaces.
-        bind_address: ''
+        # Local addresses to listen on.
+        # This will listen on all IPv4 addresses by default.
+        bind_addresses:
+          - '0.0.0.0'
 
         # This is a 'http' listener, allows us to specify 'resources'.
         type: http
@@ -188,7 +189,7 @@ class ServerConfig(Config):
       # For when matrix traffic passes through loadbalancer that unwraps TLS.
       - port: %(unsecure_port)s
         tls: false
-        bind_address: ''
+        bind_addresses: ['0.0.0.0']
         type: http
         x_forwarded: false
--
cgit 1.4.1

From 1859af9b2a904b560e7260d38147d80e72137c28 Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Sun, 18 Dec 2016 22:01:34 +0100
Subject: Update README to use `bind_addresses`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Johannes Löthberg
---
 README.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.rst b/README.rst
index 5ffcff22cd..ba21c52ae7 100644
--- a/README.rst
+++ b/README.rst
@@ -658,7 +658,7 @@ configuration might look like::
         }
     }
 
-You will also want to set ``bind_address: 127.0.0.1`` and ``x_forwarded: true``
+You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
 for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
 recorded correctly.
--
cgit 1.4.1

From f5cd5ebd7bd8582acd5805021c6718869f8519b1 Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Sun, 18 Dec 2016 23:14:32 +0100
Subject: Add IPv6 comment to default config
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Johannes Löthberg
---
 synapse/config/server.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/synapse/config/server.py b/synapse/config/server.py
index 1b9e10b527..5e6b2a68a7 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -159,6 +159,10 @@ class ServerConfig(Config):
         # This will listen on all IPv4 addresses by default.
         bind_addresses:
           - '0.0.0.0'
+          # Uncomment to listen on all IPv6 interfaces
+          # N.B: On at least Linux this will also listen on all IPv4
+          # addresses, so you will need to comment out the line above.
+          # - '::'
 
         # This is a 'http' listener, allows us to specify 'resources'.
         type: http
--
cgit 1.4.1

From a9c1b419a9c913f8cfb373335c3b7824abcf7406 Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Sun, 18 Dec 2016 23:16:43 +0100
Subject: Bump twisted dependency
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

At least 16.0.0 is needed for wrapClientTLS support.

Signed-off-by: Johannes Löthberg
---
 synapse/python_dependencies.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 3742a25b37..7817b0cd91 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -24,7 +24,7 @@ REQUIREMENTS = {
     "signedjson>=1.0.0": ["signedjson>=1.0.0"],
     "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
-    "Twisted>=15.1.0": ["twisted>=15.1.0"],
+    "Twisted>=16.0.0": ["twisted>=16.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
     "pyyaml": ["yaml"],
     "pyasn1": ["pyasn1"],
--
cgit 1.4.1

From f2a5aebf98a04bc4250ce2800b51d42543b5f35f Mon Sep 17 00:00:00 2001
From: Matthew Hodgson
Date: Sun, 18 Dec 2016 22:25:21 +0000
Subject: fix ability to change password to a non-ascii one

https://github.com/vector-im/riot-web/issues/2658
---
 synapse/handlers/auth.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3b146f09d6..652efba455 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -656,8 +656,8 @@ class AuthHandler(BaseHandler):
             Whether self.hash(password) == stored_hash (bool).
         """
         if stored_hash:
-            return bcrypt.hashpw(password + self.hs.config.password_pepper,
-                                 stored_hash.encode('utf-8')) == stored_hash
+            return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+                                 stored_hash.encode('utf8')) == stored_hash
         else:
             return False
--
cgit 1.4.1
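[Editorial note on why the encode in the password patch above matters: under Python 2, bcrypt.hashpw chokes on unicode input containing non-ascii characters, so the password has to be UTF-8 encoded before the pepper is concatenated. A quick sketch — the password and pepper values are invented for illustration:]

import bcrypt

password = u"pässword"           # non-ascii, as it arrives from a JSON body
pepper = "not-a-real-pepper"

hashed = bcrypt.hashpw(password.encode('utf8') + pepper, bcrypt.gensalt())

# Checking follows the same shape as _check_password above:
assert bcrypt.hashpw(password.encode('utf8') + pepper, hashed) == hashed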
From 702c020e58dd8663e66c25f88fba2a02264e7357 Mon Sep 17 00:00:00 2001
From: Johannes Löthberg
Date: Tue, 20 Dec 2016 01:37:50 +0100
Subject: Fix check for bind_address
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The empty string is a valid setting for the bind_address option, so
explicitly check for None here instead.

Signed-off-by: Johannes Löthberg
---
 synapse/app/appservice.py        | 4 ++--
 synapse/app/client_reader.py     | 4 ++--
 synapse/app/federation_reader.py | 4 ++--
 synapse/app/federation_sender.py | 4 ++--
 synapse/app/homeserver.py        | 4 ++--
 synapse/app/media_repository.py  | 4 ++--
 synapse/app/pusher.py            | 4 ++--
 synapse/app/synchrotron.py       | 4 ++--
 8 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index e24c1e1eda..c1379fdd7d 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -87,7 +87,7 @@ class AppserviceServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -112,7 +112,7 @@ class AppserviceServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 305a82b664..b5e1d659e6 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -110,7 +110,7 @@ class ClientReaderServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -135,7 +135,7 @@ class ClientReaderServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 321dfc7cd5..c6810b83db 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -101,7 +101,7 @@ class FederationReaderServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -126,7 +126,7 @@ class FederationReaderServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 8092fd316c..23aae8a09c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -93,7 +93,7 @@ class FederationSenderServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -118,7 +118,7 @@ class FederationSenderServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2d6becad1a..6c69ccd7e2 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -175,7 +175,7 @@ class SynapseHomeServer(HomeServer):
 
         root_resource = create_resource_tree(resources, root_resource)
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         if tls:
@@ -215,7 +215,7 @@ class SynapseHomeServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index c121107245..a47283e520 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -107,7 +107,7 @@ class MediaRepositoryServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -132,7 +132,7 @@ class MediaRepositoryServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 159850c44c..a3df375c81 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -132,7 +132,7 @@ class PusherServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -168,7 +168,7 @@ class PusherServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:

diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 56143d0025..439daaa60a 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -312,7 +312,7 @@ class SynchrotronServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        if bind_address:
+        if bind_address is not None:
             bind_addresses.append(bind_address)
 
         for address in bind_addresses:
@@ -337,7 +337,7 @@ class SynchrotronServer(HomeServer):
                 bind_address = listener.get("bind_address", None)
                 bind_addresses = listener.get("bind_addresses", [])
 
-                if bind_address:
+                if bind_address is not None:
                     bind_addresses.append(bind_address)
 
                 for address in bind_addresses:
--
cgit 1.4.1
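[Editorial note: the distinction the commit message above describes, in a few runnable lines. An empty string means "listen on all interfaces" to reactor.listenTCP, but it is falsy, so a truthiness test silently drops it; the value here is illustrative.]

bind_address = ''   # valid config meaning "listen on all interfaces"

if bind_address:
    print("truthiness check: bound")                     # never reached: '' is falsy
if bind_address is not None:
    print("None check: bound on %r" % (bind_address,))   # reached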
From 0c88ab184422739a20289ca213861986f70ae6e6 Mon Sep 17 00:00:00 2001
From: David Baker
Date: Tue, 20 Dec 2016 18:27:30 +0000
Subject: Add /account/3pid/delete endpoint

Also fix a typo in a comment
---
 synapse/handlers/auth.py                | 11 ++++++++++
 synapse/rest/client/v2_alpha/account.py | 36 ++++++++++++++++++++++++++++++++-
 synapse/storage/registration.py         | 11 ++++++++++
 3 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 652efba455..ebadace4c1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -617,6 +617,17 @@ class AuthHandler(BaseHandler):
             self.hs.get_clock().time_msec()
         )
 
+    @defer.inlineCallbacks
+    def delete_threepid(self, user_id, medium, address):
+        # 'Canonicalise' email addresses as per above
+        if medium == 'email':
+            address = address.lower()
+
+        ret = yield self.store.user_delete_threepid(
+            user_id, medium, address,
+        )
+        defer.returnValue(ret)
+
     def _save_session(self, session):
         # TODO: Persistent storage
         logger.debug("Saving session %s", session)

diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index eb49ad62e9..e74e5e0123 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -241,7 +241,7 @@ class ThreepidRestServlet(RestServlet):
 
             for reqd in ['medium', 'address', 'validated_at']:
                 if reqd not in threepid:
-                    logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
+                    logger.warn("Couldn't add 3pid: invalid response from ID server")
                     raise SynapseError(500, "Invalid response from ID Server")
 
             yield self.auth_handler.add_threepid(
@@ -263,9 +263,43 @@ class ThreepidRestServlet(RestServlet):
 
         defer.returnValue((200, {}))
 
+
+class ThreepidDeleteRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
+
+    def __init__(self, hs):
+        super(ThreepidDeleteRestServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.auth_handler = hs.get_auth_handler()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        yield run_on_reactor()
+
+        body = parse_json_object_from_request(request)
+
+        required = ['medium', 'address']
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if absent:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+
+        yield self.auth_handler.delete_threepid(
+            user_id, body['medium'], body['address']
+        )
+
+        defer.returnValue((200, {}))
+
+
 def register_servlets(hs, http_server):
     PasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
     ThreepidRequestTokenRestServlet(hs).register(http_server)
     ThreepidRestServlet(hs).register(http_server)
+    ThreepidDeleteRestServlet(hs).register(http_server)

diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 983a8ec52b..26be6060c3 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             desc="user_delete_threepids",
         )
 
+    def user_delete_threepid(self, user_id, medium, address):
+        return self._simple_delete(
+            "user_threepids",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+            },
+            desc="user_delete_threepids",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
--
cgit 1.4.1
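[Editorial note: a hypothetical client-side call to the new endpoint. The homeserver URL and access token are placeholders; the unstable path prefix is an assumption inferred from releases=() in the servlet's pattern above.]

import json
import urllib2

req = urllib2.Request(
    "https://homeserver.example:8448/_matrix/client/unstable/account/3pid/delete"
    "?access_token=ACCESS_TOKEN",
    data=json.dumps({"medium": "email", "address": "bob@example.com"}),
    headers={"Content-Type": "application/json"},
)
resp = urllib2.urlopen(req)   # expect an empty JSON object with HTTP 200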
From 84cf00c6450fb0334e4eac186f05ae6cd6afe2bb Mon Sep 17 00:00:00 2001
From: David Baker
Date: Wed, 21 Dec 2016 09:44:03 +0000
Subject: Fix another comment typo

---
 synapse/handlers/auth.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ebadace4c1..221d7ea7a2 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -607,7 +607,7 @@ class AuthHandler(BaseHandler):
         # types (mediums) of threepid. For now, we still use the existing
         # infrastructure, but this is the start of synapse gaining knowledge
         # of specific types of threepid (and fixes the fact that checking
-        # for the presenc eof an email address during password reset was
+        # for the presence of an email address during password reset was
         # case sensitive).
         if medium == 'email':
             address = address.lower()
--
cgit 1.4.1

From 555d702e3441e355d6b5e23702ab9505728dea71 Mon Sep 17 00:00:00 2001
From: Matthew Hodgson
Date: Sat, 31 Dec 2016 15:21:37 +0000
Subject: limit total timeout for get_missing_events to 10s

---
 synapse/federation/federation_client.py | 4 +++-
 synapse/federation/federation_server.py | 5 +++++
 synapse/federation/transport/client.py  | 5 +++--
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 6851f2376d..b4bcec77ed 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -707,7 +707,7 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def get_missing_events(self, destination, room_id, earliest_events_ids,
-                           latest_events, limit, min_depth):
+                           latest_events, limit, min_depth, timeout):
         """Tries to fetch events we are missing. This is called when we receive
         an event without having received all of its ancestors.
@@ -721,6 +721,7 @@ class FederationClient(FederationBase):
                 have all previous events for.
             limit (int): Maximum number of events to return.
             min_depth (int): Minimum depth of events tor return.
+            timeout (int): Max time to wait in ms
         """
         try:
             content = yield self.transport_layer.get_missing_events(
@@ -730,6 +731,7 @@ class FederationClient(FederationBase):
                 latest_events=[e.event_id for e in latest_events],
                 limit=limit,
                 min_depth=min_depth,
+                timeout=timeout,
             )
 
             events = [

diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f4c60e67e3..6d76e6f917 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -425,6 +425,7 @@ class FederationServer(FederationBase):
             " limit: %d, min_depth: %d",
             earliest_events, latest_events, limit, min_depth
         )
+
         missing_events = yield self.handler.on_get_missing_events(
             origin, room_id, earliest_events, latest_events, limit, min_depth
         )
@@ -567,6 +568,9 @@ class FederationServer(FederationBase):
                     len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
                 )
 
+                # XXX: we set timeout to 10s to help workaround
+                # https://github.com/matrix-org/synapse/issues/1733
+
                 missing_events = yield self.get_missing_events(
                     origin,
                     pdu.room_id,
                     earliest_events_ids=list(latest_extremities),
                     latest_events=[pdu],
                     limit=10,
                     min_depth=min_depth,
+                    timeout=10000,
                 )
 
                 # We want to sort these by depth so we process them and

diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 491cdc29e1..915af34409 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -386,7 +386,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def get_missing_events(self, destination, room_id, earliest_events,
-                           latest_events, limit, min_depth):
+                           latest_events, limit, min_depth, timeout):
         path = PREFIX + "/get_missing_events/%s" % (room_id,)
 
         content = yield self.client.post_json(
@@ -397,7 +397,8 @@ class TransportLayerClient(object):
                 "min_depth": int(min_depth),
                 "earliest_events": earliest_events,
                 "latest_events": latest_events,
-            }
+            },
+            timeout=timeout,
         )
 
         defer.returnValue(content)
--
cgit 1.4.1

From 8e82611f3726bfd577ca77a39328c63ecb29410f Mon Sep 17 00:00:00 2001
From: Matthew Hodgson
Date: Thu, 5 Jan 2017 11:44:44 +0000
Subject: fix comment

---
 synapse/federation/federation_server.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 6d76e6f917..800f04189f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -569,7 +569,23 @@ class FederationServer(FederationBase):
                 )
 
                 # XXX: we set timeout to 10s to help workaround
-                # https://github.com/matrix-org/synapse/issues/1733
+                # https://github.com/matrix-org/synapse/issues/1733.
+                # The reason is to avoid holding the linearizer lock
+                # whilst processing inbound /send transactions, causing
+                # FDs to stack up and block other inbound transactions
+                # which empirically can currently take up to 30 minutes.
+                #
+                # N.B. this explicitly disables retry attempts.
+                #
+                # N.B. this also increases our chances of falling back to
+                # fetching fresh state for the room if the missing event
+                # can't be found, which slightly reduces our security.
+                # it may also increase our DAG extremity count for the room,
+                # causing additional state resolution? See #1760.
+                # However, fetching state doesn't hold the linearizer lock
+                # apparently.
+                #
+                # see https://github.com/matrix-org/synapse/pull/1744
 
                 missing_events = yield self.get_missing_events(
                     origin,
--
cgit 1.4.1

From 8404f132c3fec8ca80184fe86302d653aac164ea Mon Sep 17 00:00:00 2001
From: Matthew
Date: Fri, 6 Jan 2017 23:28:46 +0000
Subject: Revert "fix typo breaking the fix to #1753"

This reverts commit b2850e62db376ea920fed9dff65a47c15cb0dc68.
---
 synapse/events/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index da9f3ad436..8c71aeb5e4 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -43,7 +43,7 @@ class _EventInternalMetadata(object):
         returns a str with the name of the server this event is sent on behalf
         of.
         """
-        return getattr(self, "send_on_behalf_of", None)
+        return getattr(self, "get_send_on_behalf_of", None)
 
 
 def _event_dict_property(key):
--
cgit 1.4.1
""" - return getattr(self, "get_send_on_behalf_of", None) + return getattr(self, "send_on_behalf_of", None) def _event_dict_property(key): -- cgit 1.4.1 From f7085ac84f76ee621ea52c9eaa0399c786d14027 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Jan 2017 17:17:10 +0000 Subject: Name linearizer's for better logs --- synapse/federation/federation_server.py | 4 ++-- synapse/handlers/room_member.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/state.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 1fee4e83a6..862ccbef5d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -52,8 +52,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() - self._room_pdu_linearizer = Linearizer() - self._server_linearizer = Linearizer() + self._room_pdu_linearizer = Linearizer("fed_room_pdu") + self._server_linearizer = Linearizer("fed_server") # We cache responses to state queries, as they take a while and often # come in waves. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2f8782e522..649aaf6d29 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -45,7 +45,7 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) - self.member_linearizer = Linearizer() + self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 692e078419..2b693ae548 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -61,7 +61,7 @@ class MediaRepository(object): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements - self.remote_media_linearizer = Linearizer() + self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() diff --git a/synapse/state.py b/synapse/state.py index 8003099c88..b9d5627a82 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -89,7 +89,7 @@ class StateHandler(object): # dict of set of event_ids -> _StateCacheEntry. self._state_cache = None - self.resolve_linearizer = Linearizer() + self.resolve_linearizer = Linearizer(name="state_resolve_lock") def start_caching(self): logger.debug("start_caching") -- cgit 1.4.1 From 6823fe52410db3b95df720b7955ad7b617dc7dee Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Jan 2017 18:25:13 +0000 Subject: Linearize updates to membership via PUT /state/ --- synapse/handlers/room_member.py | 19 +++++++++++++++---- synapse/rest/client/v1/room.py | 28 +++++++++++++++++----------- tests/rest/client/v1/test_rooms.py | 4 ++-- tests/rest/client/v1/utils.py | 5 ++++- 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2f8782e522..8e7bbe9f75 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,7 +89,7 @@ class RoomMemberHandler(BaseHandler): duplicate = yield msg_handler.deduplicate_state_event(event, context) if duplicate is not None: # Discard the new event since this membership change is a no-op. 
From 6823fe52410db3b95df720b7955ad7b617dc7dee Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 9 Jan 2017 18:25:13 +0000
Subject: Linearize updates to membership via PUT /state/

---
 synapse/handlers/room_member.py    | 19 +++++++++++++++----
 synapse/rest/client/v1/room.py     | 28 +++++++++++++++++-----------
 tests/rest/client/v1/test_rooms.py |  4 ++--
 tests/rest/client/v1/utils.py      |  5 ++++-
 4 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 2f8782e522..8e7bbe9f75 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -89,7 +89,7 @@ class RoomMemberHandler(BaseHandler):
         duplicate = yield msg_handler.deduplicate_state_event(event, context)
         if duplicate is not None:
             # Discard the new event since this membership change is a no-op.
-            return
+            defer.returnValue(duplicate)
 
         yield msg_handler.handle_new_client_event(
             requester,
@@ -120,6 +120,8 @@ class RoomMemberHandler(BaseHandler):
             if prev_member_event.membership == Membership.JOIN:
                 user_left_room(self.distributor, target, room_id)
 
+        defer.returnValue(event)
+
     @defer.inlineCallbacks
     def remote_join(self, remote_room_hosts, room_id, user, content):
         if len(remote_room_hosts) == 0:
@@ -187,6 +189,7 @@ class RoomMemberHandler(BaseHandler):
         ratelimit=True,
         content=None,
     ):
+        content_specified = bool(content)
         if content is None:
             content = {}
 
@@ -229,6 +232,12 @@ class RoomMemberHandler(BaseHandler):
                     errcode=Codes.BAD_STATE
                 )
 
+            same_content = content == old_state.content
+            same_membership = old_membership == effective_membership_state
+            same_sender = requester.user.to_string() == old_state.sender
+            if same_sender and same_membership and same_content:
+                defer.returnValue(old_state)
+
         is_host_in_room = yield self._is_host_in_room(current_state_ids)
 
         if effective_membership_state == Membership.JOIN:
@@ -247,8 +256,9 @@ class RoomMemberHandler(BaseHandler):
                 content["membership"] = Membership.JOIN
 
                 profile = self.hs.get_handlers().profile_handler
-                content["displayname"] = yield profile.get_displayname(target)
-                content["avatar_url"] = yield profile.get_avatar_url(target)
+                if not content_specified:
+                    content["displayname"] = yield profile.get_displayname(target)
+                    content["avatar_url"] = yield profile.get_avatar_url(target)
 
                 if requester.is_guest:
                     content["kind"] = "guest"
@@ -290,7 +300,7 @@ class RoomMemberHandler(BaseHandler):
 
                 defer.returnValue({})
 
-        yield self._local_membership_update(
+        res = yield self._local_membership_update(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -300,6 +310,7 @@ class RoomMemberHandler(BaseHandler):
             prev_event_ids=latest_event_ids,
             content=content,
         )
+        defer.returnValue(res)
 
     @defer.inlineCallbacks
     def send_membership_event(

diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index eead435bfd..2ebf5e59a0 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -152,23 +152,29 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
         if state_key is not None:
             event_dict["state_key"] = state_key
 
-        msg_handler = self.handlers.message_handler
-        event, context = yield msg_handler.create_event(
-            event_dict,
-            token_id=requester.access_token_id,
-            txn_id=txn_id,
-        )
-
         if event_type == EventTypes.Member:
-            yield self.handlers.room_member_handler.send_membership_event(
+            membership = content.get("membership", None)
+            event = yield self.handlers.room_member_handler.update_membership(
                 requester,
-                event,
-                context,
+                target=UserID.from_string(state_key),
+                room_id=room_id,
+                action=membership,
+                content=content,
             )
         else:
+            msg_handler = self.handlers.message_handler
+            event, context = yield msg_handler.create_event(
+                event_dict,
+                token_id=requester.access_token_id,
+                txn_id=txn_id,
+            )
+
             yield msg_handler.send_nonmember_event(requester, event, context)
 
-        defer.returnValue((200, {"event_id": event.event_id}))
+        ret = {}
+        if event:
+            ret = {"event_id": event.event_id}
+        defer.returnValue((200, ret))
 
     # TODO: Needs unit testing for generic events + feedback

diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4fe99ebc0b..6bce352c5f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
         # set [invite/join/left] of self, set [invite/join/left] of other,
         # expect all 404s because room doesn't exist on any server
         for usr in [self.user_id, self.rmcreator_id]:
-            yield self.join(room=room, user=usr, expect_code=403)
-            yield self.leave(room=room, user=usr, expect_code=403)
+            yield self.join(room=room, user=usr, expect_code=404)
+            yield self.leave(room=room, user=usr, expect_code=404)
 
     @defer.inlineCallbacks
     def test_membership_private_room_perms(self):

diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 17524b2e23..3bb1dd003a 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -87,7 +87,10 @@ class RestTestCase(unittest.TestCase):
         (code, response) = yield self.mock_resource.trigger(
             "PUT", path, json.dumps(data)
         )
-        self.assertEquals(expect_code, code, msg=str(response))
+        self.assertEquals(
+            expect_code, code,
+            msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response)
+        )
 
         self.auth_user_id = temp_id
--
cgit 1.4.1

From 586f474a44cf320c7d578714aea5113f2701073c Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 12:46:00 +0000
Subject: Don't block messages sending on bumping presence

---
 synapse/handlers/message.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a57a69bd3..59d0ad3bdc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -279,7 +279,7 @@ class MessageHandler(BaseHandler):
 
         if event.type == EventTypes.Message:
             presence = self.hs.get_presence_handler()
-            yield presence.bump_presence_active_time(user)
+            preserve_fn(presence.bump_presence_active_time)(user)
 
     @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
--
cgit 1.4.1
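[Editorial note: the one-line change above swaps a blocking yield for a fire-and-forget call. Schematically, preserve_fn wraps a deferred-returning function so it can be kicked off without awaiting its result while keeping the log context intact; slow_side_effect here is a stand-in, not the real presence handler.]

from twisted.internet import defer
from synapse.util.logcontext import preserve_fn

@defer.inlineCallbacks
def slow_side_effect(user):
    yield defer.succeed(None)   # stand-in for presence bookkeeping

@defer.inlineCallbacks
def send_message(user, body):
    # Fire and forget: the returned deferred is deliberately not yielded,
    # so a slow side effect cannot delay the message send.
    preserve_fn(slow_side_effect)(user)
    yield defer.succeed("event_id")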
preserve_fn(presence.bump_presence_active_time)(user) @defer.inlineCallbacks -- cgit 1.4.1 From 32019c98971fae775fe79bb30615899e7f6b09d4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 10 Jan 2017 14:19:50 +0000 Subject: Log which files we saved attachments to in the media_repository --- synapse/rest/media/v1/media_repository.py | 4 ++++ synapse/rest/media/v1/thumbnailer.py | 5 +++++ synapse/rest/media/v1/upload_resource.py | 2 ++ 3 files changed, 11 insertions(+) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 2b693ae548..3cbeca503c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -98,6 +98,8 @@ class MediaRepository(object): with open(fname, "wb") as f: f.write(content) + logger.info("Stored local media in file %r", fname) + yield self.store.store_local_media( media_id=media_id, media_type=media_type, @@ -190,6 +192,8 @@ class MediaRepository(object): else: upload_name = None + logger.info("Stored remote media in file %r", fname) + yield self.store.store_cached_remote_media( origin=server_name, media_id=media_id, diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 0bb3676844..3868d4f65f 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -16,6 +16,10 @@ import PIL.Image as Image from io import BytesIO +import logging + +logger = logging.getLogger(__name__) + class Thumbnailer(object): @@ -86,4 +90,5 @@ class Thumbnailer(object): output_bytes = output_bytes_io.getvalue() with open(output_path, "wb") as output_file: output_file.write(output_bytes) + logger.info("Stored thumbnail in file %r", output_path) return len(output_bytes) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index b716d1d892..4ab33f73bf 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -97,6 +97,8 @@ class UploadResource(Resource): content_length, requester.user ) + logger.info("Uploaded content with URI %r", content_uri) + respond_with_json( request, 200, {"content_uri": content_uri}, send_cors=True ) -- cgit 1.4.1 From dd52d4de4c9318b377a880b2662124688fc23129 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 14:34:50 +0000 Subject: Limit number of entries to prefill from cache Some tables, like device_inbox, take a long time to query at startup for the stream change cache prefills. This is likely because they are slower growing streams and so are more fragmented on disk. For now, lets pull fewer entries out to make startup quicker. In future, we should add a better index to make it even faster. 
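
For reference, the bounded prefill query this commit introduces boils down to the
following standalone sketch. The SQL string mirrors the diff below; the function
name and the cursor handling around it are illustrative, not the actual
SQLBaseStore code (which also converts the `?` paramstyle per database engine):

    def prefill_stream_change_cache(db_conn, table, entity_column,
                                    stream_column, max_value, limit=1000):
        # Scan only the most recent `limit` stream positions instead of a
        # fixed 100000-row window, so slow-growing, fragmented tables such
        # as device_inbox are cheap to query at startup.
        sql = (
            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
            " WHERE %(stream)s > ? - %(limit)s"
            " GROUP BY %(entity)s"
        ) % {
            "table": table,
            "entity": entity_column,
            "stream": stream_column,
            "limit": limit,
        }
        txn = db_conn.cursor()
        # Assumes the qmark (?) paramstyle, e.g. sqlite.
        txn.execute(sql, (max_value,))
        return dict(txn.fetchall())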
---
 synapse/storage/__init__.py | 4 +++-
 synapse/storage/_base.py | 5 +++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index fe936b3e62..e8495f1eb9 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -189,7 +189,8 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "device_inbox",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=max_device_inbox_id
+            max_value=max_device_inbox_id,
+            limit=1000,
         )
         self._device_inbox_stream_cache = StreamChangeCache(
             "DeviceInboxStreamChangeCache", min_device_inbox_id,
@@ -202,6 +203,7 @@ class DataStore(RoomMemberStore, RoomStore,
             entity_column="destination",
             stream_column="stream_id",
             max_value=max_device_inbox_id,
+            limit=1000,
         )
         self._device_federation_outbox_stream_cache = StreamChangeCache(
             "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b62c459d8b..5620a655eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -838,18 +838,19 @@ class SQLBaseStore(object):
         return txn.execute(sql, keyvalues.values())

     def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
-                        max_value):
+                        max_value, limit=100000):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
         sql = (
             "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - 100000"
+            " WHERE %(stream)s > ? - %(limit)s"
             " GROUP BY %(entity)s"
         ) % {
             "table": table,
             "entity": entity_column,
             "stream": stream_column,
+            "limit": limit,
         }

         sql = self.database_engine.convert_param_style(sql)
-- cgit 1.4.1


From caddadfc5ac61d1c91fbaf29bf3298f90a140560 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 15:04:57 +0000
Subject: Change device_inbox stream index to include user

This makes fetching the most recently changed users much trickier, and
brings it in line with e.g. presence_stream indices.
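
In plain DDL terms, the two background updates registered in the diff below
amount to the following sketch of the net effect. The statements assume
PostgreSQL; the real code builds the new index via the background-update
machinery (concurrently, in batches) and only then schedules the drop, rather
than doing both in one transaction:

    def upgrade_device_inbox_indices(txn):
        # New composite index: lets us find which users have messages past a
        # given stream position, mirroring the presence_stream indices.
        txn.execute(
            "CREATE INDEX device_inbox_stream_id_user_id"
            " ON device_inbox (stream_id, user_id)"
        )
        # The old stream-only index is redundant once the above exists.
        txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")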
--- synapse/storage/deviceinbox.py | 38 ++++++++++++++++++++++-- synapse/storage/prepare_database.py | 2 +- synapse/storage/schema/delta/40/device_inbox.sql | 20 +++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 synapse/storage/schema/delta/40/device_inbox.sql diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2821eb89c9..b71ac3ae39 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -18,13 +18,29 @@ import ujson from twisted.internet import defer -from ._base import SQLBaseStore +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) -class DeviceInboxStore(SQLBaseStore): +class DeviceInboxStore(BackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, hs): + super(DeviceInboxStore, self).__init__(hs) + + self.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, + self._background_drop_index_device_inbox, + ) @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, @@ -368,3 +384,21 @@ class DeviceInboxStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_inbox_stream_id" + ) + finally: + conn.set_session(autocommit=False) + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + defer.returnValue(1) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e46ae6502e..b357f22be7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 39 +SCHEMA_VERSION = 40 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql new file mode 100644 index 0000000000..ce58fe2082 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_inbox.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); -- cgit 1.4.1 From 5a32e9273ec9759caf09d5b8204dd29e7a007b97 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 15:11:27 +0000 Subject: Don't disable autocommit --- synapse/storage/deviceinbox.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index b71ac3ae39..b0ab70bafe 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -388,14 +388,10 @@ class DeviceInboxStore(BackgroundUpdateStore): @defer.inlineCallbacks def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "DROP INDEX IF EXISTS device_inbox_stream_id" - ) - finally: - conn.set_session(autocommit=False) + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_inbox_stream_id" + ) yield self.runWithConnection(reindex_txn) -- cgit 1.4.1 From ab655dca339f8d4168079cc2b4529dc50265fc83 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 15:15:25 +0000 Subject: Explicitly close the cursor --- synapse/storage/deviceinbox.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index b0ab70bafe..bde3b5cbbc 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -392,6 +392,7 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.execute( "DROP INDEX IF EXISTS device_inbox_stream_id" ) + txn.close() yield self.runWithConnection(reindex_txn) -- cgit 1.4.1 From 8a0fddfd73784b0e06c8d61337279c45c96e8687 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 16:30:53 +0000 Subject: Remove spurious for..else.. --- synapse/app/pusher.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index a3df375c81..57e097fa11 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -146,17 +146,6 @@ class PusherServer(HomeServer): ), interface=address ) - else: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=bind_address - ) logger.info("Synapse pusher now listening on port %d", port) -- cgit 1.4.1 From 3cb1799347f0100418052db3b0ed18c229b22456 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 10 Jan 2017 16:50:35 +0000 Subject: credit patrik properly --- CHANGES.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES.rst b/CHANGES.rst index 68e9d8c671..9106134b46 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -30,6 +30,7 @@ Changes in synapse v0.18.6 (2017-01-06) Bug fixes: * Fix bug when checking if a guest user is allowed to join a room (PR #1772) + Thanks to Patrik Oldsberg for diagnosing and the fix! 
Changes in synapse v0.18.6-rc3 (2017-01-05) -- cgit 1.4.1 From edd6cdfc9a1cf180871657baaf2aa6da5845f25a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 17:21:41 +0000 Subject: Restore default bind address --- synapse/app/appservice.py | 12 ++---------- synapse/app/client_reader.py | 12 ++---------- synapse/app/federation_reader.py | 12 ++---------- synapse/app/federation_sender.py | 12 ++---------- synapse/app/homeserver.py | 12 ++---------- synapse/app/media_repository.py | 12 ++---------- synapse/app/pusher.py | 12 ++---------- synapse/app/synchrotron.py | 12 ++---------- synapse/config/server.py | 17 +++++++++++++---- 9 files changed, 29 insertions(+), 84 deletions(-) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index c1379fdd7d..1900930053 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -76,8 +76,7 @@ class AppserviceServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -87,9 +86,6 @@ class AppserviceServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -109,11 +105,7 @@ class AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index b5e1d659e6..4d081eccd1 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -90,8 +90,7 @@ class ClientReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -110,9 +109,6 @@ class ClientReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -132,11 +128,7 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index c6810b83db..90a4816753 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,8 +86,7 @@ class FederationReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = 
listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -101,9 +100,6 @@ class FederationReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -123,11 +119,7 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 23aae8a09c..ec06620efb 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -82,8 +82,7 @@ class FederationSenderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -93,9 +92,6 @@ class FederationSenderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -115,11 +111,7 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6c69ccd7e2..e0b87468fe 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -107,8 +107,7 @@ def build_resource_for_web_client(hs): class SynapseHomeServer(HomeServer): def _listener_http(self, config, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] tls = listener_config.get("tls", False) site_tag = listener_config.get("tag", port) @@ -175,9 +174,6 @@ class SynapseHomeServer(HomeServer): root_resource = create_resource_tree(resources, root_resource) - if bind_address is not None: - bind_addresses.append(bind_address) - if tls: for address in bind_addresses: reactor.listenSSL( @@ -212,11 +208,7 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/media_repository.py 
b/synapse/app/media_repository.py index a47283e520..ef17b158a5 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -87,8 +87,7 @@ class MediaRepositoryServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -107,9 +106,6 @@ class MediaRepositoryServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -129,11 +125,7 @@ class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 57e097fa11..073f2c2489 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -121,8 +121,7 @@ class PusherServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -132,9 +131,6 @@ class PusherServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -154,11 +150,7 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 439daaa60a..4dfc2dc648 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -289,8 +289,7 @@ class SynchrotronServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -312,9 +311,6 @@ class SynchrotronServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -334,11 +330,7 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - 
bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/config/server.py b/synapse/config/server.py index 5e6b2a68a7..59687ee395 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -42,6 +42,15 @@ class ServerConfig(Config): self.listeners = config.get("listeners", []) + for listener in self.listeners: + bind_address = listener.get("bind_address", None) + bind_addresses = listener.setdefault("bind_addresses", []) + + if bind_address: + bind_addresses.append(bind_address) + elif not bind_addresses: + bind_addresses.append('') + self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) bind_port = config.get("bind_port") @@ -54,7 +63,7 @@ class ServerConfig(Config): self.listeners.append({ "port": bind_port, - "bind_address": bind_host, + "bind_addresses": [bind_host], "tls": True, "type": "http", "resources": [ @@ -73,7 +82,7 @@ class ServerConfig(Config): if unsecure_port: self.listeners.append({ "port": unsecure_port, - "bind_address": bind_host, + "bind_addresses": [bind_host], "tls": False, "type": "http", "resources": [ @@ -92,7 +101,7 @@ class ServerConfig(Config): if manhole: self.listeners.append({ "port": manhole, - "bind_address": "127.0.0.1", + "bind_addresses": ["127.0.0.1"], "type": "manhole", }) @@ -100,7 +109,7 @@ class ServerConfig(Config): if metrics_port: self.listeners.append({ "port": metrics_port, - "bind_address": config.get("metrics_bind_host", "127.0.0.1"), + "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], "tls": False, "type": "http", "resources": [ -- cgit 1.4.1 From b1dfd202928174ca5b377196b813e6ce51fe0999 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 17:23:18 +0000 Subject: Pop bind_address --- synapse/config/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/config/server.py b/synapse/config/server.py index 59687ee395..1f9999d57a 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -43,7 +43,7 @@ class ServerConfig(Config): self.listeners = config.get("listeners", []) for listener in self.listeners: - bind_address = listener.get("bind_address", None) + bind_address = listener.pop("bind_address", None) bind_addresses = listener.setdefault("bind_addresses", []) if bind_address: -- cgit 1.4.1 From 7e6c2937c327c76e32fb663d4a94072a0492c338 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 18:16:54 +0000 Subject: Split out static auth methods from Auth object --- synapse/api/auth.py | 986 ++++++++++++++++++++++++++++------------------------ 1 file changed, 531 insertions(+), 455 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f93e45a744..5e2b89c324 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -27,7 +27,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.types import UserID, get_domain_from_id from synapse.util.logcontext import preserve_context_over_fn -from synapse.util.logutils import log_function from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -43,30 +42,9 @@ AuthEventTypes = ( GUEST_DEVICE_ID = "guest_device" -class Auth(object): - """ - FIXME: This class contains a mix of functions for authenticating users - of our client-server API and authenticating events added to room graphs. 
- """ - def __init__(self, hs): - self.hs = hs - self.clock = hs.get_clock() - self.store = hs.get_datastore() - self.state = hs.get_state_handler() - self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 - - @defer.inlineCallbacks - def check_from_context(self, event, context, do_sig_check=True): - auth_events_ids = yield self.compute_auth_events( - event, context.prev_state_ids, for_verification=True, - ) - auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } - self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) - - def check(self, event, auth_events, do_sig_check=True): +class Auther(object): + @staticmethod + def check(event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. Args: @@ -77,133 +55,133 @@ class Auth(object): Returns: True if the auth checks pass. """ - with Measure(self.clock, "auth.check"): - self.check_size_limits(event) + Auther.check_size_limits(event) - if not hasattr(event, "room_id"): - raise AuthError(500, "Event has no room_id: %s" % event) + if not hasattr(event, "room_id"): + raise AuthError(500, "Event has no room_id: %s" % event) - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - event_id_domain = get_domain_from_id(event.event_id) + if do_sig_check: + sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - # Check the event_id's domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. 
+ if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + if auth_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True - if event.type == EventTypes.Create: - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" - ) - # FIXME - return True + if event.type == EventTypes.Create: + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, + "Creation event's room_id domain does not match sender's" + ) + # FIXME + return True - creation_event = auth_events.get((EventTypes.Create, ""), None) + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) - if not creation_event: - raise SynapseError( + creating_domain = get_domain_from_id(event.room_id) + originating_domain = get_domain_from_id(event.sender) + if creating_domain != originating_domain: + if not Auther.can_federate(event, auth_events): + raise AuthError( 403, - "Room %r does not exist" % (event.room_id,) + "This room has been marked as unfederatable." ) - creating_domain = get_domain_from_id(event.room_id) - originating_domain = get_domain_from_id(event.sender) - if creating_domain != originating_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + # FIXME: Temp hack + if event.type == EventTypes.Aliases: + if not event.is_state(): + raise AuthError( + 403, + "Alias event must be a state event", + ) + if not event.state_key: + raise AuthError( + 403, + "Alias event must have non-empty state_key" + ) + sender_domain = get_domain_from_id(event.sender) + if event.state_key != sender_domain: + raise AuthError( + 403, + "Alias event's state_key does not match sender's domain" + ) + return True - # FIXME: Temp hack - if event.type == EventTypes.Aliases: - if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) - if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) - sender_domain = get_domain_from_id(event.sender) - if event.state_key != sender_domain: - raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" - ) - return True + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] + if event.type == EventTypes.Member: + allowed = Auther.is_membership_change_allowed( + event, auth_events ) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed - if event.type == EventTypes.Member: - allowed = self.is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! 
%s", event) - return allowed - - self.check_event_sender_in_room(event, auth_events) + Auther.check_event_sender_in_room(event, auth_events) - # Special case to allow m.room.third_party_invite events wherever - # a user is allowed to issue invites. Fixes - # https://github.com/vector-im/vector-web/issues/1208 hopefully - if event.type == EventTypes.ThirdPartyInvite: - user_level = self._get_user_power_level(event.user_id, auth_events) - invite_level = self._get_named_level(auth_events, "invite", 0) + # Special case to allow m.room.third_party_invite events wherever + # a user is allowed to issue invites. Fixes + # https://github.com/vector-im/vector-web/issues/1208 hopefully + if event.type == EventTypes.ThirdPartyInvite: + user_level = Auther._get_user_power_level(event.user_id, auth_events) + invite_level = Auther._get_named_level(auth_events, "invite", 0) - if user_level < invite_level: - raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) + if user_level < invite_level: + raise AuthError( + 403, ( + "You cannot issue a third party invite for %s." % + (event.content.display_name,) ) - else: - return True + ) + else: + return True - self._can_send_event(event, auth_events) + Auther._can_send_event(event, auth_events) - if event.type == EventTypes.PowerLevels: - self._check_power_levels(event, auth_events) + if event.type == EventTypes.PowerLevels: + Auther._check_power_levels(event, auth_events) - if event.type == EventTypes.Redaction: - self.check_redaction(event, auth_events) + if event.type == EventTypes.Redaction: + Auther.check_redaction(event, auth_events) - logger.debug("Allowing! %s", event) + logger.debug("Allowing! %s", event) - def check_size_limits(self, event): + @staticmethod + def check_size_limits(event): def too_big(field): raise EventSizeError("%s too large" % (field,)) @@ -220,109 +198,14 @@ class Auth(object): if len(encode_canonical_json(event.get_pdu_json())) > 65536: too_big("event") - @defer.inlineCallbacks - def check_joined_room(self, room_id, user_id, current_state=None): - """Check if the user is currently joined in the room - Args: - room_id(str): The room to check. - user_id(str): The user to check. - current_state(dict): Optional map of the current state of the room. - If provided then that map is used to check whether they are a - member of the room. Otherwise the current membership is - loaded from the database. - Raises: - AuthError if the user is not in the room. - Returns: - A deferred membership event for the user if the user is in - the room. - """ - if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) - else: - member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - - self._check_joined_room(member, user_id, room_id) - defer.returnValue(member) - - @defer.inlineCallbacks - def check_user_was_in_room(self, room_id, user_id): - """Check if the user was in the room at some point. - Args: - room_id(str): The room to check. - user_id(str): The user to check. - Raises: - AuthError if the user was never in the room. - Returns: - A deferred membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. 
- """ - member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - membership = member.membership if member else None - - if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) - - if membership == Membership.LEAVE: - forgot = yield self.store.did_forget(user_id, room_id) - if forgot: - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) - - defer.returnValue(member) - - @defer.inlineCallbacks - def check_host_in_room(self, room_id, host): - with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - - logger.info("calling resolve_state_groups from check_host_in_room") - entry = yield self.state.resolve_state_groups( - room_id, latest_event_ids - ) - - ret = yield self.store.is_host_joined( - room_id, host, entry.state_group, entry.state - ) - defer.returnValue(ret) - - def check_event_sender_in_room(self, event, auth_events): - key = (EventTypes.Member, event.user_id, ) - member_event = auth_events.get(key) - - return self._check_joined_room( - member_event, - event.user_id, - event.room_id - ) - - def _check_joined_room(self, member, user_id, room_id): - if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) - - def can_federate(self, event, auth_events): + @staticmethod + def can_federate(event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) return creation_event.content.get("m.federate", True) is True - @log_function - def is_membership_change_allowed(self, event, auth_events): + @staticmethod + def is_membership_change_allowed(event, auth_events): membership = event.content["membership"] # Check if this is the room creator joining: @@ -339,7 +222,7 @@ class Auth(object): creating_domain = get_domain_from_id(event.room_id) target_domain = get_domain_from_id(target_user_id) if creating_domain != target_domain: - if not self.can_federate(event, auth_events): + if not Auther.can_federate(event, auth_events): raise AuthError( 403, "This room has been marked as unfederatable." @@ -368,13 +251,13 @@ class Auth(object): else: join_rule = JoinRules.INVITE - user_level = self._get_user_power_level(event.user_id, auth_events) - target_level = self._get_user_power_level( + user_level = Auther._get_user_power_level(event.user_id, auth_events) + target_level = Auther._get_user_power_level( target_user_id, auth_events ) # FIXME (erikj): What should we do here as the default? - ban_level = self._get_named_level(auth_events, "ban", 50) + ban_level = Auther._get_named_level(auth_events, "ban", 50) logger.debug( "is_membership_change_allowed: %s", @@ -391,7 +274,7 @@ class Auth(object): ) if Membership.INVITE == membership and "third_party_invite" in event.content: - if not self._verify_third_party_invite(event, auth_events): + if not Auther._verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.") if target_banned: raise AuthError( @@ -424,7 +307,7 @@ class Auth(object): raise AuthError(403, "%s is already in the room." % target_user_id) else: - invite_level = self._get_named_level(auth_events, "invite", 0) + invite_level = Auther._get_named_level(auth_events, "invite", 0) if user_level < invite_level: raise AuthError( @@ -454,7 +337,7 @@ class Auth(object): 403, "You cannot unban user &s." 
% (target_user_id,) ) elif target_user_id != event.user_id: - kick_level = self._get_named_level(auth_events, "kick", 50) + kick_level = Auther._get_named_level(auth_events, "kick", 50) if user_level < kick_level or user_level <= target_level: raise AuthError( @@ -468,7 +351,234 @@ class Auth(object): return True - def _verify_third_party_invite(self, event, auth_events): + @staticmethod + def check_event_sender_in_room(event, auth_events): + key = (EventTypes.Member, event.user_id, ) + member_event = auth_events.get(key) + + return Auther._check_joined_room( + member_event, + event.user_id, + event.room_id + ) + + @staticmethod + def _check_joined_room(member, user_id, room_id): + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s (%s)" % ( + user_id, room_id, repr(member) + )) + + @staticmethod + def _get_send_level(etype, state_key, auth_events): + key = (EventTypes.PowerLevels, "", ) + send_level_event = auth_events.get(key) + send_level = None + if send_level_event: + send_level = send_level_event.content.get("events", {}).get( + etype + ) + if send_level is None: + if state_key is not None: + send_level = send_level_event.content.get( + "state_default", 50 + ) + else: + send_level = send_level_event.content.get( + "events_default", 0 + ) + + if send_level: + send_level = int(send_level) + else: + send_level = 0 + + return send_level + + @staticmethod + def _can_send_event(event, auth_events): + send_level = Auther._get_send_level( + event.type, event.get("state_key", None), auth_events + ) + user_level = Auther._get_user_power_level(event.user_id, auth_events) + + if user_level < send_level: + raise AuthError( + 403, + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level) + ) + + # Check state_key + if hasattr(event, "state_key"): + if event.state_key.startswith("@"): + if event.state_key != event.user_id: + raise AuthError( + 403, + "You are not allowed to set others state" + ) + + return True + + @staticmethod + def check_redaction(event, auth_events): + """Check whether the event sender is allowed to redact the target event. + + Returns: + True if the the sender is allowed to redact the target event if the + target event was created by them. + False if the sender is allowed to redact the target event with no + further checks. + + Raises: + AuthError if the event sender is definitely not allowed to redact + the target event. 
+ """ + user_level = Auther._get_user_power_level(event.user_id, auth_events) + + redact_level = Auther._get_named_level(auth_events, "redact", 50) + + if user_level >= redact_level: + return False + + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) + + @staticmethod + def _check_power_levels(event, auth_events): + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): + try: + UserID.from_string(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + key = (event.type, event.state_key, ) + current_state = auth_events.get(key) + + if not current_state: + return + + user_level = Auther._get_user_power_level(event.user_id, auth_events) + + # Check other levels: + levels_to_check = [ + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, "users") + ) + + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, "events") + ) + + old_state = current_state.content + new_state = event.content + + for level_to_check, dir in levels_to_check: + old_loc = old_state + new_loc = new_state + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) + + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None + + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None + + if new_level is not None and old_level is not None: + if new_level == old_level: + continue + + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + + if old_level > user_level or new_level > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + @staticmethod + def _get_power_level_event(auth_events): + key = (EventTypes.PowerLevels, "", ) + return auth_events.get(key) + + @staticmethod + def _get_user_power_level(user_id, auth_events): + power_level_event = Auther._get_power_level_event(auth_events) + + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level = power_level_event.content.get("users_default", 0) + + if level is None: + return 0 + else: + return int(level) + else: + key = (EventTypes.Create, "", ) + create_event = auth_events.get(key) + if (create_event is not None and + create_event.content["creator"] == user_id): + return 100 + else: + return 0 + + @staticmethod + def _get_named_level(auth_events, name, default): + power_level_event = Auther._get_power_level_event(auth_events) + + if not power_level_event: + return default + + level = power_level_event.content.get(name, None) + if level is not None: + return int(level) + else: + return default + + @staticmethod + def _verify_third_party_invite(event, auth_events): """ Validates that the invite event is authorized by 
a previous third-party invite. @@ -504,84 +614,197 @@ class Auth(object): if invite_event.sender != event.sender: return False - if event.user_id != invite_event.user_id: - return False + if event.user_id != invite_event.user_id: + return False + + if signed["mxid"] != event.state_key: + return False + if signed["token"] != token: + return False + + for public_key_object in Auther.get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + @staticmethod + def get_public_keys(invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys + + +class Auth(object): + """ + FIXME: This class contains a mix of functions for authenticating users + of our client-server API and authenticating events added to room graphs. + """ + def __init__(self, hs): + self.hs = hs + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) + + def check(self, event, auth_events, do_sig_check=True): + """ Checks if this event is correctly authed. + + Args: + event: the event being checked. + auth_events (dict: event-key -> event): the existing room state. + + + Returns: + True if the auth checks pass. + """ + with Measure(self.clock, "auth.check"): + Auther.check(event, auth_events, do_sig_check=do_sig_check) + + def check_size_limits(self, event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + + @defer.inlineCallbacks + def check_joined_room(self, room_id, user_id, current_state=None): + """Check if the user is currently joined in the room + Args: + room_id(str): The room to check. + user_id(str): The user to check. + current_state(dict): Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. 
Otherwise the current membership is + loaded from the database. + Raises: + AuthError if the user is not in the room. + Returns: + A deferred membership event for the user if the user is in + the room. + """ + if current_state: + member = current_state.get( + (EventTypes.Member, user_id), + None + ) + else: + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + + self._check_joined_room(member, user_id, room_id) + defer.returnValue(member) - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False + @defer.inlineCallbacks + def check_user_was_in_room(self, room_id, user_id): + """Check if the user was in the room at some point. + Args: + room_id(str): The room to check. + user_id(str): The user to check. + Raises: + AuthError if the user was never in the room. + Returns: + A deferred membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. + """ + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None - for public_key_object in self.get_public_keys(invite_event): - public_key = public_key_object["public_key"] - try: - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - continue - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) + if membership not in (Membership.JOIN, Membership.LEAVE): + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. 
- return True - except (KeyError, SignatureVerifyException,): - continue - return False + if membership == Membership.LEAVE: + forgot = yield self.store.did_forget(user_id, room_id) + if forgot: + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) - def get_public_keys(self, invite_event): - public_keys = [] - if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } - if "key_validity_url" in invite_event.content: - o["key_validity_url"] = invite_event.content["key_validity_url"] - public_keys.append(o) - public_keys.extend(invite_event.content.get("public_keys", [])) - return public_keys + defer.returnValue(member) - def _get_power_level_event(self, auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) + @defer.inlineCallbacks + def check_host_in_room(self, room_id, host): + with Measure(self.clock, "check_host_in_room"): + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - def _get_user_power_level(self, user_id, auth_events): - power_level_event = self._get_power_level_event(auth_events) + logger.info("calling resolve_state_groups from check_host_in_room") + entry = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level = power_level_event.content.get("users_default", 0) + ret = yield self.store.is_host_joined( + room_id, host, entry.state_group, entry.state + ) + defer.returnValue(ret) - if level is None: - return 0 - else: - return int(level) - else: - key = (EventTypes.Create, "", ) - create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): - return 100 - else: - return 0 + def _check_joined_room(self, member, user_id, room_id): + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s (%s)" % ( + user_id, room_id, repr(member) + )) - def _get_named_level(self, auth_events, name, default): - power_level_event = self._get_power_level_event(auth_events) + def can_federate(self, event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) - if not power_level_event: - return default + return creation_event.content.get("m.federate", True) is True - level = power_level_event.content.get(name, None) - if level is not None: - return int(level) - else: - return default + def get_public_keys(self, invite_event): + return Auther.get_public_keys(invite_event) @defer.inlineCallbacks def get_user_by_req(self, request, allow_guest=False, rights="access"): @@ -975,54 +1198,7 @@ class Auth(object): defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - return send_level - - @log_function - def _can_send_event(self, event, auth_events): - send_level = self._get_send_level( - event.type, event.get("state_key", None), auth_events - ) - user_level = self._get_user_power_level(event.user_id, auth_events) - - if user_level < 
send_level: - raise AuthError( - 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) - ) - - # Check state_key - if hasattr(event, "state_key"): - if event.state_key.startswith("@"): - if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) - - return True + return Auther._get_send_level(etype, state_key, auth_events) def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. @@ -1037,107 +1213,7 @@ class Auth(object): AuthError if the event sender is definitely not allowed to redact the target event. """ - user_level = self._get_user_power_level(event.user_id, auth_events) - - redact_level = self._get_named_level(auth_events, "redact", 50) - - if user_level >= redact_level: - return False - - redacter_domain = get_domain_from_id(event.event_id) - redactee_domain = get_domain_from_id(event.redacts) - if redacter_domain == redactee_domain: - return True - - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - def _check_power_levels(self, event, auth_events): - user_list = event.content.get("users", {}) - # Validate users - for k, v in user_list.items(): - try: - UserID.from_string(k) - except: - raise SynapseError(400, "Not a valid user_id: %s" % (k,)) - - try: - int(v) - except: - raise SynapseError(400, "Not a valid power level: %s" % (v,)) - - key = (event.type, event.state_key, ) - current_state = auth_events.get(key) - - if not current_state: - return - - user_level = self._get_user_power_level(event.user_id, auth_events) - - # Check other levels: - levels_to_check = [ - ("users_default", None), - ("events_default", None), - ("state_default", None), - ("ban", None), - ("redact", None), - ("kick", None), - ("invite", None), - ] - - old_list = current_state.content.get("users") - for user in set(old_list.keys() + user_list.keys()): - levels_to_check.append( - (user, "users") - ) - - old_list = current_state.content.get("events") - new_list = event.content.get("events") - for ev_id in set(old_list.keys() + new_list.keys()): - levels_to_check.append( - (ev_id, "events") - ) - - old_state = current_state.content - new_state = event.content - - for level_to_check, dir in levels_to_check: - old_loc = old_state - new_loc = new_state - if dir: - old_loc = old_loc.get(dir, {}) - new_loc = new_loc.get(dir, {}) - - if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) - else: - old_level = None - - if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) - else: - new_level = None - - if new_level is not None and old_level is not None: - if new_level == old_level: - continue - - if dir == "users" and level_to_check != event.user_id: - if old_level == user_level: - raise AuthError( - 403, - "You don't have permission to remove ops level equal " - "to your own" - ) - - if old_level > user_level or new_level > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) + return Auther.check_redaction(event, auth_events) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): @@ -1167,10 +1243,10 @@ class Auth(object): if power_level_event: auth_events[(EventTypes.PowerLevels, "")] = power_level_event - send_level = self._get_send_level( + send_level = Auther._get_send_level( EventTypes.Aliases, "", auth_events ) - user_level = self._get_user_power_level(user_id, auth_events) 
+        user_level = Auther._get_user_power_level(user_id, auth_events)

         if user_level < send_level:
             raise AuthError(
-- cgit 1.4.1


From 7b62d0bc70903d3c0cf49d67db31fa3682e33a55 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 11 Jan 2017 10:57:03 +0000
Subject: Add missing None check

---
 synapse/handlers/room_member.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8a76469b77..b2806555cf 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -232,11 +232,12 @@ class RoomMemberHandler(BaseHandler):
                     errcode=Codes.BAD_STATE
                 )

-        same_content = content == old_state.content
-        same_membership = old_membership == effective_membership_state
-        same_sender = requester.user.to_string() == old_state.sender
-        if same_sender and same_membership and same_content:
-            defer.returnValue(old_state)
+        if old_state:
+            same_content = content == old_state.content
+            same_membership = old_membership == effective_membership_state
+            same_sender = requester.user.to_string() == old_state.sender
+            if same_sender and same_membership and same_content:
+                defer.returnValue(old_state)

         is_host_in_room = yield self._is_host_in_room(current_state_ids)

-- cgit 1.4.1


From bf5c9706d9053ffad05fc12eca71b8d441fa9306 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Thu, 12 Jan 2017 10:32:52 +0000
Subject: Remove full_twisted_stacktraces option

The debug 'full_twisted_stacktraces' flag caused synapse to rewrite
twisted deferreds to always fire the callback on the next reactor tick.
This was to force the deferred to always store the stacktraces on
exceptions, and thus be more likely to have a full stacktrace when it
reaches the final error handlers and gets printed to the logs.

Dynamically rewriting things is generally bad, and in particular this
change violates assumptions of various bits of Twisted. This wouldn't
necessarily be so bad, but it turns out this option has been turned on
on some production servers.

Turning the option on can cause e.g. #1778.

For now, let's just entirely nuke this option.
---
 synapse/config/logger.py |  8 ------
 synapse/util/debug.py    | 71 ------------------------------------------------
 2 files changed, 79 deletions(-)
 delete mode 100644 synapse/util/debug.py

diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 63e69a7e0c..77ded0ad25 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -22,7 +22,6 @@ import yaml
 from string import Template
 import os
 import signal
-from synapse.util.debug import debug_deferreds


 DEFAULT_LOG_CONFIG = Template("""
@@ -71,8 +70,6 @@ class LoggingConfig(Config):
         self.verbosity = config.get("verbose", 0)
         self.log_config = self.abspath(config.get("log_config"))
         self.log_file = self.abspath(config.get("log_file"))
-        if config.get("full_twisted_stacktraces"):
-            debug_deferreds()

     def default_config(self, config_dir_path, server_name, **kwargs):
         log_file = self.abspath("homeserver.log")
@@ -88,11 +85,6 @@ class LoggingConfig(Config):

         # A yaml python logging config file
         log_config: "%(log_config)s"
-
-        # Stop twisted from discarding the stack traces of exceptions in
-        # deferreds by waiting a reactor tick before running a deferred's
-        # callbacks.
- # full_twisted_stacktraces: true """ % locals() def read_arguments(self, args): diff --git a/synapse/util/debug.py b/synapse/util/debug.py deleted file mode 100644 index dc49162e6a..0000000000 --- a/synapse/util/debug.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer, reactor -from functools import wraps -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext - - -def debug_deferreds(): - """Cause all deferreds to wait for a reactor tick before running their - callbacks. This increases the chance of getting a stack trace out of - a defer.inlineCallback since the code waiting on the deferred will get - a chance to add an errback before the deferred runs.""" - - # Helper method for retrieving and restoring the current logging context - # around a callback. - def with_logging_context(fn): - context = LoggingContext.current_context() - - def restore_context_callback(x): - with PreserveLoggingContext(context): - return fn(x) - - return restore_context_callback - - # We are going to modify the __init__ method of defer.Deferred so we - # need to get a copy of the old method so we can still call it. - old__init__ = defer.Deferred.__init__ - - # We need to create a deferred to bounce the callbacks through the reactor - # but we don't want to add a callback when we create that deferred so we - # we create a new type of deferred that uses the old __init__ method. - # This is safe as long as the old __init__ method doesn't invoke an - # __init__ using super. - class Bouncer(defer.Deferred): - __init__ = old__init__ - - # We'll add this as a callback to all Deferreds. Twisted will wait until - # the bouncer deferred resolves before calling the callbacks of the - # original deferred. - def bounce_callback(x): - bouncer = Bouncer() - reactor.callLater(0, with_logging_context(bouncer.callback), x) - return bouncer - - # We'll add this as an errback to all Deferreds. Twisted will wait until - # the bouncer deferred resolves before calling the errbacks of the - # original deferred. 
-    def bounce_errback(x):
-        bouncer = Bouncer()
-        reactor.callLater(0, with_logging_context(bouncer.errback), x)
-        return bouncer
-
-    @wraps(old__init__)
-    def new__init__(self, *args, **kargs):
-        old__init__(self, *args, **kargs)
-        self.addCallbacks(bounce_callback, bounce_errback)
-
-    defer.Deferred.__init__ = new__init__
-- cgit 1.4.1


From ebf94aff8d8cf6a6ed187b2c8e6aaa69f3912a48 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Thu, 12 Jan 2017 17:19:47 +0000
Subject: Fix spurious Unhandled Error log lines

---
 synapse/rest/client/transactions.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 351170edbc..efa77b8c51 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -86,7 +86,11 @@ class HttpTransactionCache(object):
             pass  # execute the function instead.

         deferred = fn(*args, **kwargs)
-        observable = ObservableDeferred(deferred)
+
+        # We don't add an errback to the raw deferred, so we ask ObservableDeferred
+        # to swallow the error. This is fine as the error will still be reported
+        # to the observers.
+        observable = ObservableDeferred(deferred, consumeErrors=True)
         self.transactions[txn_key] = (observable, self.clock.time_msec())
         return observable.observe()

-- cgit 1.4.1


From 6f5e41e420d3a928c59841640610e1ad7756121c Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Fri, 13 Jan 2017 12:52:11 +0000
Subject: README.rst: fix formatting

Fix formatting blooper introduced in
https://github.com/matrix-org/synapse/pull/1672 :/

---
 README.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.rst b/README.rst
index ba21c52ae7..77e0b470a3 100644
--- a/README.rst
+++ b/README.rst
@@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
     python-devel libffi-devel libopenssl-devel libjpeg62-devel

 Installing prerequisites on OpenBSD::
+
     doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
                  libxslt

-- cgit 1.4.1


From 8b2fa382568373573d3b1d520e8ebc2ef39e2935 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 13 Jan 2017 15:07:32 +0000
Subject: Split event auth code into separate module

---
 synapse/api/auth.py   | 654 +-------------------------------------------------
 synapse/event_auth.py | 641 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 650 insertions(+), 645 deletions(-)
 create mode 100644 synapse/event_auth.py

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5e2b89c324..b781d41a66 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -16,16 +16,13 @@
 import logging

 import pymacaroons
-from canonicaljson import encode_canonical_json
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json, SignatureVerifyException
 from twisted.internet import defer
-from unpaddedbase64 import decode_base64

 import synapse.types
+from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import UserID, get_domain_from_id
+from synapse.api.errors import AuthError, Codes
+from synapse.types import UserID
 from synapse.util.logcontext import preserve_context_over_fn
 from synapse.util.metrics import Measure

@@ -42,622 +39,6 @@ AuthEventTypes = (
 GUEST_DEVICE_ID = "guest_device"


-class Auther(object):
-    @staticmethod
-    def check(event, auth_events, do_sig_check=True):
-        """ Checks if this event is correctly authed.
- - Args: - event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. - - - Returns: - True if the auth checks pass. - """ - Auther.check_size_limits(event) - - if not hasattr(event, "room_id"): - raise AuthError(500, "Event has no room_id: %s" % event) - - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - event_id_domain = get_domain_from_id(event.event_id) - - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) - - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - # Check the event_id's domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True - - if event.type == EventTypes.Create: - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" - ) - # FIXME - return True - - creation_event = auth_events.get((EventTypes.Create, ""), None) - - if not creation_event: - raise SynapseError( - 403, - "Room %r does not exist" % (event.room_id,) - ) - - creating_domain = get_domain_from_id(event.room_id) - originating_domain = get_domain_from_id(event.sender) - if creating_domain != originating_domain: - if not Auther.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # FIXME: Temp hack - if event.type == EventTypes.Aliases: - if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) - if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) - sender_domain = get_domain_from_id(event.sender) - if event.state_key != sender_domain: - raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" - ) - return True - - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) - - if event.type == EventTypes.Member: - allowed = Auther.is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed - - Auther.check_event_sender_in_room(event, auth_events) - - # Special case to allow m.room.third_party_invite events wherever - # a user is allowed to issue invites. Fixes - # https://github.com/vector-im/vector-web/issues/1208 hopefully - if event.type == EventTypes.ThirdPartyInvite: - user_level = Auther._get_user_power_level(event.user_id, auth_events) - invite_level = Auther._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." 
% - (event.content.display_name,) - ) - ) - else: - return True - - Auther._can_send_event(event, auth_events) - - if event.type == EventTypes.PowerLevels: - Auther._check_power_levels(event, auth_events) - - if event.type == EventTypes.Redaction: - Auther.check_redaction(event, auth_events) - - logger.debug("Allowing! %s", event) - - @staticmethod - def check_size_limits(event): - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - - if len(event.user_id) > 255: - too_big("user_id") - if len(event.room_id) > 255: - too_big("room_id") - if event.is_state() and len(event.state_key) > 255: - too_big("state_key") - if len(event.type) > 255: - too_big("type") - if len(event.event_id) > 255: - too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: - too_big("event") - - @staticmethod - def can_federate(event, auth_events): - creation_event = auth_events.get((EventTypes.Create, "")) - - return creation_event.content.get("m.federate", True) is True - - @staticmethod - def is_membership_change_allowed(event, auth_events): - membership = event.content["membership"] - - # Check if this is the room creator joining: - if len(event.prev_events) == 1 and Membership.JOIN == membership: - # Get room creation event: - key = (EventTypes.Create, "", ) - create = auth_events.get(key) - if create and event.prev_events[0][0] == create.event_id: - if create.content["creator"] == event.state_key: - return True - - target_user_id = event.state_key - - creating_domain = get_domain_from_id(event.room_id) - target_domain = get_domain_from_id(target_user_id) - if creating_domain != target_domain: - if not Auther.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # get info about the caller - key = (EventTypes.Member, event.user_id, ) - caller = auth_events.get(key) - - caller_in_room = caller and caller.membership == Membership.JOIN - caller_invited = caller and caller.membership == Membership.INVITE - - # get info about the target - key = (EventTypes.Member, target_user_id, ) - target = auth_events.get(key) - - target_in_room = target and target.membership == Membership.JOIN - target_banned = target and target.membership == Membership.BAN - - key = (EventTypes.JoinRules, "", ) - join_rule_event = auth_events.get(key) - if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) - else: - join_rule = JoinRules.INVITE - - user_level = Auther._get_user_power_level(event.user_id, auth_events) - target_level = Auther._get_user_power_level( - target_user_id, auth_events - ) - - # FIXME (erikj): What should we do here as the default? 
- ban_level = Auther._get_named_level(auth_events, "ban", 50) - - logger.debug( - "is_membership_change_allowed: %s", - { - "caller_in_room": caller_in_room, - "caller_invited": caller_invited, - "target_banned": target_banned, - "target_in_room": target_in_room, - "membership": membership, - "join_rule": join_rule, - "target_user_id": target_user_id, - "event.user_id": event.user_id, - } - ) - - if Membership.INVITE == membership and "third_party_invite" in event.content: - if not Auther._verify_third_party_invite(event, auth_events): - raise AuthError(403, "You are not invited to this room.") - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - return True - - if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): - return True - - if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) - - if Membership.INVITE == membership: - # TODO (erikj): We should probably handle this more intelligently - # PRIVATE join rules. - - # Invites are valid iff caller is in the room and target isn't. - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) - else: - invite_level = Auther._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, "You cannot invite user %s." % target_user_id - ) - elif Membership.JOIN == membership: - # Joins are valid iff caller == target and they were: - # invited: They are accepting the invitation - # joined: It's a NOOP - if event.user_id != target_user_id: - raise AuthError(403, "Cannot force another user to join.") - elif target_banned: - raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC: - pass - elif join_rule == JoinRules.INVITE: - if not caller_in_room and not caller_invited: - raise AuthError(403, "You are not invited to this room.") - else: - # TODO (erikj): may_join list - # TODO (erikj): private rooms - raise AuthError(403, "You are not allowed to join this room") - elif Membership.LEAVE == membership: - # TODO (erikj): Implement kicks. - if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) - ) - elif target_user_id != event.user_id: - kick_level = Auther._get_named_level(auth_events, "kick", 50) - - if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." 
% target_user_id - ) - elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: - raise AuthError(403, "You don't have permission to ban") - else: - raise AuthError(500, "Unknown membership %s" % membership) - - return True - - @staticmethod - def check_event_sender_in_room(event, auth_events): - key = (EventTypes.Member, event.user_id, ) - member_event = auth_events.get(key) - - return Auther._check_joined_room( - member_event, - event.user_id, - event.room_id - ) - - @staticmethod - def _check_joined_room(member, user_id, room_id): - if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) - - @staticmethod - def _get_send_level(etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - return send_level - - @staticmethod - def _can_send_event(event, auth_events): - send_level = Auther._get_send_level( - event.type, event.get("state_key", None), auth_events - ) - user_level = Auther._get_user_power_level(event.user_id, auth_events) - - if user_level < send_level: - raise AuthError( - 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) - ) - - # Check state_key - if hasattr(event, "state_key"): - if event.state_key.startswith("@"): - if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) - - return True - - @staticmethod - def check_redaction(event, auth_events): - """Check whether the event sender is allowed to redact the target event. - - Returns: - True if the the sender is allowed to redact the target event if the - target event was created by them. - False if the sender is allowed to redact the target event with no - further checks. - - Raises: - AuthError if the event sender is definitely not allowed to redact - the target event. 
- """ - user_level = Auther._get_user_power_level(event.user_id, auth_events) - - redact_level = Auther._get_named_level(auth_events, "redact", 50) - - if user_level >= redact_level: - return False - - redacter_domain = get_domain_from_id(event.event_id) - redactee_domain = get_domain_from_id(event.redacts) - if redacter_domain == redactee_domain: - return True - - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - @staticmethod - def _check_power_levels(event, auth_events): - user_list = event.content.get("users", {}) - # Validate users - for k, v in user_list.items(): - try: - UserID.from_string(k) - except: - raise SynapseError(400, "Not a valid user_id: %s" % (k,)) - - try: - int(v) - except: - raise SynapseError(400, "Not a valid power level: %s" % (v,)) - - key = (event.type, event.state_key, ) - current_state = auth_events.get(key) - - if not current_state: - return - - user_level = Auther._get_user_power_level(event.user_id, auth_events) - - # Check other levels: - levels_to_check = [ - ("users_default", None), - ("events_default", None), - ("state_default", None), - ("ban", None), - ("redact", None), - ("kick", None), - ("invite", None), - ] - - old_list = current_state.content.get("users") - for user in set(old_list.keys() + user_list.keys()): - levels_to_check.append( - (user, "users") - ) - - old_list = current_state.content.get("events") - new_list = event.content.get("events") - for ev_id in set(old_list.keys() + new_list.keys()): - levels_to_check.append( - (ev_id, "events") - ) - - old_state = current_state.content - new_state = event.content - - for level_to_check, dir in levels_to_check: - old_loc = old_state - new_loc = new_state - if dir: - old_loc = old_loc.get(dir, {}) - new_loc = new_loc.get(dir, {}) - - if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) - else: - old_level = None - - if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) - else: - new_level = None - - if new_level is not None and old_level is not None: - if new_level == old_level: - continue - - if dir == "users" and level_to_check != event.user_id: - if old_level == user_level: - raise AuthError( - 403, - "You don't have permission to remove ops level equal " - "to your own" - ) - - if old_level > user_level or new_level > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) - - @staticmethod - def _get_power_level_event(auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) - - @staticmethod - def _get_user_power_level(user_id, auth_events): - power_level_event = Auther._get_power_level_event(auth_events) - - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level = power_level_event.content.get("users_default", 0) - - if level is None: - return 0 - else: - return int(level) - else: - key = (EventTypes.Create, "", ) - create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): - return 100 - else: - return 0 - - @staticmethod - def _get_named_level(auth_events, name, default): - power_level_event = Auther._get_power_level_event(auth_events) - - if not power_level_event: - return default - - level = power_level_event.content.get(name, None) - if level is not None: - return int(level) - else: - return default - - @staticmethod - def _verify_third_party_invite(event, auth_events): - """ - Validates that the invite event is 
authorized by a previous third-party invite. - - Checks that the public key, and keyserver, match those in the third party invite, - and that the invite event has a signature issued using that public key. - - Args: - event: The m.room.member join event being validated. - auth_events: All relevant previous context events which may be used - for authorization decisions. - - Return: - True if the event fulfills the expectations of a previous third party - invite event. - """ - if "third_party_invite" not in event.content: - return False - if "signed" not in event.content["third_party_invite"]: - return False - signed = event.content["third_party_invite"]["signed"] - for key in {"mxid", "token"}: - if key not in signed: - return False - - token = signed["token"] - - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) - if not invite_event: - return False - - if invite_event.sender != event.sender: - return False - - if event.user_id != invite_event.user_id: - return False - - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - - for public_key_object in Auther.get_public_keys(invite_event): - public_key = public_key_object["public_key"] - try: - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - continue - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True - except (KeyError, SignatureVerifyException,): - continue - return False - - @staticmethod - def get_public_keys(invite_event): - public_keys = [] - if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } - if "key_validity_url" in invite_event.content: - o["key_validity_url"] = invite_event.content["key_validity_url"] - public_keys.append(o) - public_keys.extend(invite_event.content.get("public_keys", [])) - return public_keys - - class Auth(object): """ FIXME: This class contains a mix of functions for authenticating users @@ -693,24 +74,7 @@ class Auth(object): True if the auth checks pass. 
""" with Measure(self.clock, "auth.check"): - Auther.check(event, auth_events, do_sig_check=do_sig_check) - - def check_size_limits(self, event): - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - - if len(event.user_id) > 255: - too_big("user_id") - if len(event.room_id) > 255: - too_big("room_id") - if event.is_state() and len(event.state_key) > 255: - too_big("state_key") - if len(event.type) > 255: - too_big("type") - if len(event.event_id) > 255: - too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: - too_big("event") + event_auth.check(event, auth_events, do_sig_check=do_sig_check) @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): @@ -804,7 +168,7 @@ class Auth(object): return creation_event.content.get("m.federate", True) is True def get_public_keys(self, invite_event): - return Auther.get_public_keys(invite_event) + return event_auth.get_public_keys(invite_event) @defer.inlineCallbacks def get_user_by_req(self, request, allow_guest=False, rights="access"): @@ -1198,7 +562,7 @@ class Auth(object): defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): - return Auther._get_send_level(etype, state_key, auth_events) + return event_auth._get_send_level(etype, state_key, auth_events) def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. @@ -1213,7 +577,7 @@ class Auth(object): AuthError if the event sender is definitely not allowed to redact the target event. """ - return Auther.check_redaction(event, auth_events) + return event_auth.check_redaction(event, auth_events) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): @@ -1243,10 +607,10 @@ class Auth(object): if power_level_event: auth_events[(EventTypes.PowerLevels, "")] = power_level_event - send_level = Auther._get_send_level( + send_level = event_auth.get_send_level( EventTypes.Aliases, "", auth_events ) - user_level = Auther._get_user_power_level(user_id, auth_events) + user_level = event_auth.get_user_power_level(user_id, auth_events) if user_level < send_level: raise AuthError( diff --git a/synapse/event_auth.py b/synapse/event_auth.py new file mode 100644 index 0000000000..983d8e9a85 --- /dev/null +++ b/synapse/event_auth.py @@ -0,0 +1,641 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from canonicaljson import encode_canonical_json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json, SignatureVerifyException +from unpaddedbase64 import decode_base64 + +from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.errors import AuthError, SynapseError, EventSizeError +from synapse.types import UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + + +def check(event, auth_events, do_sig_check=True): + """ Checks if this event is correctly authed. 
+ + Args: + event: the event being checked. + auth_events (dict: event-key -> event): the existing room state. + + + Returns: + True if the auth checks pass. + """ + _check_size_limits(event) + + if not hasattr(event, "room_id"): + raise AuthError(500, "Event has no room_id: %s" % event) + + if do_sig_check: + sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + if auth_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True + + if event.type == EventTypes.Create: + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, + "Creation event's room_id domain does not match sender's" + ) + # FIXME + return True + + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) + + creating_domain = get_domain_from_id(event.room_id) + originating_domain = get_domain_from_id(event.sender) + if creating_domain != originating_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # FIXME: Temp hack + if event.type == EventTypes.Aliases: + if not event.is_state(): + raise AuthError( + 403, + "Alias event must be a state event", + ) + if not event.state_key: + raise AuthError( + 403, + "Alias event must have non-empty state_key" + ) + sender_domain = get_domain_from_id(event.sender) + if event.state_key != sender_domain: + raise AuthError( + 403, + "Alias event's state_key does not match sender's domain" + ) + return True + + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) + + if event.type == EventTypes.Member: + allowed = _is_membership_change_allowed( + event, auth_events + ) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed + + _check_event_sender_in_room(event, auth_events) + + # Special case to allow m.room.third_party_invite events wherever + # a user is allowed to issue invites. Fixes + # https://github.com/vector-im/vector-web/issues/1208 hopefully + if event.type == EventTypes.ThirdPartyInvite: + user_level = get_user_power_level(event.user_id, auth_events) + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, ( + "You cannot issue a third party invite for %s." 
% + (event.content.display_name,) + ) + ) + else: + return True + + _can_send_event(event, auth_events) + + if event.type == EventTypes.PowerLevels: + _check_power_levels(event, auth_events) + + if event.type == EventTypes.Redaction: + check_redaction(event, auth_events) + + logger.debug("Allowing! %s", event) + + +def _check_size_limits(event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + + +def _can_federate(event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + + +def _is_membership_change_allowed(event, auth_events): + membership = event.content["membership"] + + # Check if this is the room creator joining: + if len(event.prev_events) == 1 and Membership.JOIN == membership: + # Get room creation event: + key = (EventTypes.Create, "", ) + create = auth_events.get(key) + if create and event.prev_events[0][0] == create.event_id: + if create.content["creator"] == event.state_key: + return True + + target_user_id = event.state_key + + creating_domain = get_domain_from_id(event.room_id) + target_domain = get_domain_from_id(target_user_id) + if creating_domain != target_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # get info about the caller + key = (EventTypes.Member, event.user_id, ) + caller = auth_events.get(key) + + caller_in_room = caller and caller.membership == Membership.JOIN + caller_invited = caller and caller.membership == Membership.INVITE + + # get info about the target + key = (EventTypes.Member, target_user_id, ) + target = auth_events.get(key) + + target_in_room = target and target.membership == Membership.JOIN + target_banned = target and target.membership == Membership.BAN + + key = (EventTypes.JoinRules, "", ) + join_rule_event = auth_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: + join_rule = JoinRules.INVITE + + user_level = get_user_power_level(event.user_id, auth_events) + target_level = get_user_power_level( + target_user_id, auth_events + ) + + # FIXME (erikj): What should we do here as the default? 
+    ban_level = _get_named_level(auth_events, "ban", 50)
+
+    logger.debug(
+        "_is_membership_change_allowed: %s",
+        {
+            "caller_in_room": caller_in_room,
+            "caller_invited": caller_invited,
+            "target_banned": target_banned,
+            "target_in_room": target_in_room,
+            "membership": membership,
+            "join_rule": join_rule,
+            "target_user_id": target_user_id,
+            "event.user_id": event.user_id,
+        }
+    )
+
+    if Membership.INVITE == membership and "third_party_invite" in event.content:
+        if not _verify_third_party_invite(event, auth_events):
+            raise AuthError(403, "You are not invited to this room.")
+        if target_banned:
+            raise AuthError(
+                403, "%s is banned from the room" % (target_user_id,)
+            )
+        return True
+
+    if Membership.JOIN != membership:
+        if (caller_invited
+                and Membership.LEAVE == membership
+                and target_user_id == event.user_id):
+            return True
+
+        if not caller_in_room:  # caller isn't joined
+            raise AuthError(
+                403,
+                "%s not in room %s." % (event.user_id, event.room_id,)
+            )
+
+    if Membership.INVITE == membership:
+        # TODO (erikj): We should probably handle this more intelligently
+        #   PRIVATE join rules.
+
+        # Invites are valid iff caller is in the room and target isn't.
+        if target_banned:
+            raise AuthError(
+                403, "%s is banned from the room" % (target_user_id,)
+            )
+        elif target_in_room:  # the target is already in the room.
+            raise AuthError(403, "%s is already in the room." %
+                            target_user_id)
+        else:
+            invite_level = _get_named_level(auth_events, "invite", 0)
+
+            if user_level < invite_level:
+                raise AuthError(
+                    403, "You cannot invite user %s." % target_user_id
+                )
+    elif Membership.JOIN == membership:
+        # Joins are valid iff caller == target and they were:
+        # invited: They are accepting the invitation
+        # joined: It's a NOOP
+        if event.user_id != target_user_id:
+            raise AuthError(403, "Cannot force another user to join.")
+        elif target_banned:
+            raise AuthError(403, "You are banned from this room")
+        elif join_rule == JoinRules.PUBLIC:
+            pass
+        elif join_rule == JoinRules.INVITE:
+            if not caller_in_room and not caller_invited:
+                raise AuthError(403, "You are not invited to this room.")
+        else:
+            # TODO (erikj): may_join list
+            # TODO (erikj): private rooms
+            raise AuthError(403, "You are not allowed to join this room")
+    elif Membership.LEAVE == membership:
+        # TODO (erikj): Implement kicks.
+        if target_banned and user_level < ban_level:
+            raise AuthError(
+                403, "You cannot unban user %s." % (target_user_id,)
+            )
+        elif target_user_id != event.user_id:
+            kick_level = _get_named_level(auth_events, "kick", 50)
+
+            if user_level < kick_level or user_level <= target_level:
+                raise AuthError(
+                    403, "You cannot kick user %s."
% target_user_id
+                )
+    elif Membership.BAN == membership:
+        if user_level < ban_level or user_level <= target_level:
+            raise AuthError(403, "You don't have permission to ban")
+    else:
+        raise AuthError(500, "Unknown membership %s" % membership)
+
+    return True
+
+
+def _check_event_sender_in_room(event, auth_events):
+    key = (EventTypes.Member, event.user_id, )
+    member_event = auth_events.get(key)
+
+    return _check_joined_room(
+        member_event,
+        event.user_id,
+        event.room_id
+    )
+
+
+def _check_joined_room(member, user_id, room_id):
+    if not member or member.membership != Membership.JOIN:
+        raise AuthError(403, "User %s not in room %s (%s)" % (
+            user_id, room_id, repr(member)
+        ))
+
+
+def get_send_level(etype, state_key, auth_events):
+    key = (EventTypes.PowerLevels, "", )
+    send_level_event = auth_events.get(key)
+    send_level = None
+    if send_level_event:
+        send_level = send_level_event.content.get("events", {}).get(
+            etype
+        )
+        if send_level is None:
+            if state_key is not None:
+                send_level = send_level_event.content.get(
+                    "state_default", 50
+                )
+            else:
+                send_level = send_level_event.content.get(
+                    "events_default", 0
+                )
+
+    if send_level:
+        send_level = int(send_level)
+    else:
+        send_level = 0
+
+    return send_level
+
+
+def _can_send_event(event, auth_events):
+    send_level = get_send_level(
+        event.type, event.get("state_key", None), auth_events
+    )
+    user_level = get_user_power_level(event.user_id, auth_events)
+
+    if user_level < send_level:
+        raise AuthError(
+            403,
+            "You don't have permission to post that to the room. " +
+            "user_level (%d) < send_level (%d)" % (user_level, send_level)
+        )
+
+    # Check state_key
+    if hasattr(event, "state_key"):
+        if event.state_key.startswith("@"):
+            if event.state_key != event.user_id:
+                raise AuthError(
+                    403,
+                    "You are not allowed to set others state"
+                )
+
+    return True
+
+
+def check_redaction(event, auth_events):
+    """Check whether the event sender is allowed to redact the target event.
+
+    Returns:
+        True if the sender is allowed to redact the target event if the
+        target event was created by them.
+        False if the sender is allowed to redact the target event with no
+        further checks.
+
+    Raises:
+        AuthError if the event sender is definitely not allowed to redact
+        the target event.
+ """ + user_level = get_user_power_level(event.user_id, auth_events) + + redact_level = _get_named_level(auth_events, "redact", 50) + + if user_level >= redact_level: + return False + + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) + + +def _check_power_levels(event, auth_events): + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): + try: + UserID.from_string(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + key = (event.type, event.state_key, ) + current_state = auth_events.get(key) + + if not current_state: + return + + user_level = get_user_power_level(event.user_id, auth_events) + + # Check other levels: + levels_to_check = [ + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, "users") + ) + + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, "events") + ) + + old_state = current_state.content + new_state = event.content + + for level_to_check, dir in levels_to_check: + old_loc = old_state + new_loc = new_state + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) + + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None + + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None + + if new_level is not None and old_level is not None: + if new_level == old_level: + continue + + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + + if old_level > user_level or new_level > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + +def _get_power_level_event(auth_events): + key = (EventTypes.PowerLevels, "", ) + return auth_events.get(key) + + +def get_user_power_level(user_id, auth_events): + power_level_event = _get_power_level_event(auth_events) + + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level = power_level_event.content.get("users_default", 0) + + if level is None: + return 0 + else: + return int(level) + else: + key = (EventTypes.Create, "", ) + create_event = auth_events.get(key) + if (create_event is not None and + create_event.content["creator"] == user_id): + return 100 + else: + return 0 + + +def _get_named_level(auth_events, name, default): + power_level_event = _get_power_level_event(auth_events) + + if not power_level_event: + return default + + level = power_level_event.content.get(name, None) + if level is not None: + return int(level) + else: + return default + + +def _verify_third_party_invite(event, auth_events): + """ + Validates that the invite event is authorized by a previous third-party invite. 
+ + Checks that the public key, and keyserver, match those in the third party invite, + and that the invite event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if "third_party_invite" not in event.content: + return False + if "signed" not in event.content["third_party_invite"]: + return False + signed = event.content["third_party_invite"]["signed"] + for key in {"mxid", "token"}: + if key not in signed: + return False + + token = signed["token"] + + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + return False + + if invite_event.sender != event.sender: + return False + + if event.user_id != invite_event.user_id: + return False + + if signed["mxid"] != event.state_key: + return False + if signed["token"] != token: + return False + + for public_key_object in get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + +def get_public_keys(invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys -- cgit 1.4.1 From a3e4a198e3f5e0acd91d40d5743f97ece2cf5b6f Mon Sep 17 00:00:00 2001 From: Adrian Perez de Castro Date: Fri, 13 Jan 2017 17:12:04 +0200 Subject: Allow configuring the Riot URL used in notification emails The URLs used for notification emails were hardcoded to use either matrix.to or vector.im; but for self-hosted setups where Riot is also self-hosted it may be desirable to allow configuring an alternative Riot URL. Fixes #1809. 
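As an illustration (not part of the patch itself), here is a sketch of the
link formats the make_room_link / make_notif_link changes below produce once
riot_base_url is set; the base URL and the room/event IDs are made-up
examples:

    # Python sketch mirroring the patched link construction
    riot_base_url = "https://riot.example.com"    # hypothetical config value
    room_id = "!abc123:example.com"               # made-up room ID
    event_id = "$1484300000123XYZab:example.com"  # made-up event ID

    # room link, as built by make_room_link with the option set
    print("%s/#/room/%s" % (riot_base_url, room_id))
    # -> https://riot.example.com/#/room/!abc123:example.com

    # notification link, as built by make_notif_link with the option set
    print("%s/#/room/%s/%s" % (riot_base_url, room_id, event_id))
    # -> https://riot.example.com/#/room/!abc123:example.com/$1484300000123XYZab:example.com
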
Signed-off-by: Adrian Perez de Castro
---
 synapse/config/emailconfig.py |  7 +++++++
 synapse/push/mailer.py        | 20 ++++++++++++++------
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index a187161272..0030b5db1e 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -68,6 +68,9 @@ class EmailConfig(Config):
             self.email_notif_for_new_users = email_config.get(
                 "notif_for_new_users", True
             )
+            self.email_riot_base_url = email_config.get(
+                "riot_base_url", None
+            )
             if "app_name" in email_config:
                 self.email_app_name = email_config["app_name"]
             else:
@@ -85,6 +88,9 @@ class EmailConfig(Config):
     def default_config(self, config_dir_path, server_name, **kwargs):
         return """
         # Enable sending emails for notification events
+        # Defining a custom URL for Riot is only needed if email notifications
+        # should contain links to a self-hosted installation of Riot; when set
+        # the "app_name" setting is ignored.
         #email:
         #   enable_notifs: false
         #   smtp_host: "localhost"
@@ -95,4 +101,5 @@ class EmailConfig(Config):
         #   notif_template_html: notif_mail.html
         #   notif_template_text: notif_mail.txt
         #   notif_for_new_users: True
+        #   riot_base_url: "http://localhost/riot"
         """

diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 53551632b6..ce2d31fb98 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -439,15 +439,23 @@ class Mailer(object):
         })

     def make_room_link(self, room_id):
-        # need /beta for Universal Links to work on iOS
-        if self.app_name == "Vector":
-            return "https://vector.im/beta/#/room/%s" % (room_id,)
+        if self.hs.config.email_riot_base_url:
+            base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
+        elif self.app_name == "Vector":
+            # need /beta for Universal Links to work on iOS
+            base_url = "https://vector.im/beta/#/room"
         else:
-            return "https://matrix.to/#/%s" % (room_id,)
+            base_url = "https://matrix.to/#"
+        return "%s/%s" % (base_url, room_id)

     def make_notif_link(self, notif):
-        # need /beta for Universal Links to work on iOS
-        if self.app_name == "Vector":
+        if self.hs.config.email_riot_base_url:
+            return "%s/#/room/%s/%s" % (
+                self.hs.config.email_riot_base_url,
+                notif['room_id'], notif['event_id']
+            )
+        elif self.app_name == "Vector":
+            # need /beta for Universal Links to work on iOS
             return "https://vector.im/beta/#/room/%s/%s" % (
                 notif['room_id'], notif['event_id']
             )

-- cgit 1.4.1


From c050f493dd53a74206338f9a5e567d7bd24fbd5d Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 13 Jan 2017 15:14:41 +0000
Subject: Add comment

---
 synapse/storage/schema/delta/40/device_inbox.sql | 1 +
 1 file changed, 1 insertion(+)

diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql
index ce58fe2082..b9fe1f0480 100644
--- a/synapse/storage/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/schema/delta/40/device_inbox.sql
@@ -13,6 +13,7 @@
  * limitations under the License.
  */

+-- turn the pre-fill startup query into an index-only scan on postgresql.
INSERT into background_updates (update_name, progress_json) VALUES ('device_inbox_stream_index', '{}'); -- cgit 1.4.1 From e178feca3f9063c7a4f768298e889ee54b471e9b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 15:16:45 +0000 Subject: Remove unused function --- synapse/api/auth.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b781d41a66..280d4c4452 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -561,9 +561,6 @@ class Auth(object): defer.returnValue(auth_ids) - def _get_send_level(self, etype, state_key, auth_events): - return event_auth._get_send_level(etype, state_key, auth_events) - def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. -- cgit 1.4.1 From ec0a523ac338bab1eb23a6b21227b8f7402cc2d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 18:37:18 +0000 Subject: Split out static state methods from StateHandler --- synapse/state.py | 144 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 74 insertions(+), 70 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index b9d5627a82..c75499c3e0 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,6 +16,7 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure @@ -335,9 +336,10 @@ class StateHandler(object): [state_map[e_id] for key, e_id in st.items() if e_id in state_map] for st in state_groups_ids.values() ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) + with Measure(self.clock, "state._resolve_events"): + new_state, _ = Resolver.resolve_events( + state_sets, event_type, state_key + ) new_state = { key: e.event_id for key, e in new_state.items() } @@ -388,68 +390,78 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) + with Measure(self.clock, "state._resolve_events"): + if event.is_state(): + return Resolver.resolve_events( + state_sets, event.type, event.state_key + ) + else: + return Resolver.resolve_events(state_sets) + - def _resolve_events(self, state_sets, event_type=None, state_key=""): +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) + + +class Resolver(object): + @staticmethod + def resolve_events(state_sets, event_type=None, state_key=""): """ Returns (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple (new_state, prev_states). new_state is a map from (type, state_key) to event. prev_states is a list of event_ids. 
""" - with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e + + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + try: + resolved_state = Resolver._resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - new_state = unconflicted_state - new_state.update(resolved_state) + new_state = unconflicted_state + new_state.update(resolved_state) return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): + @staticmethod + def _resolve_state_events(conflicted_state, auth_events): """ This is where we actually decide which of the conflicted state to use. 
@@ -464,7 +476,7 @@ class StateHandler(object): if power_key in conflicted_state: events = conflicted_state[power_key] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( + resolved_state[power_key] = Resolver._resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -472,7 +484,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -482,7 +494,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -492,14 +504,15 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( + resolved_state[key] = Resolver._resolve_normal_events( events, auth_events ) return resolved_state - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] + @staticmethod + def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] auth_events = dict(auth_events) @@ -507,23 +520,20 @@ class StateHandler(object): for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: return prev_event return event - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): + @staticmethod + def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) return event except AuthError: pass @@ -531,9 +541,3 @@ class StateHandler(object): # Use the last event (the one with the least depth) if they all fail # the auth check. 
return event - - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - - return sorted(events, key=key_func) -- cgit 1.4.1 From 2fae34bd2ce152b8544d5a90fe3b35281c5fffbc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 17:46:17 +0000 Subject: Optionally measure size of cache by sum of length of values --- synapse/storage/roommember.py | 3 ++- synapse/storage/state.py | 2 +- synapse/util/caches/descriptors.py | 25 ++++++++++++++++++++----- synapse/util/caches/lrucache.py | 32 ++++++++++++++++++-------------- tests/util/test_lrucache.py | 25 +++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 21 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..e63aab6ccf 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=2000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..c480743f89 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=1000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49f..d082c26b1f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -42,6 +42,13 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +def deferred_size(deferred): + if deferred.called: + return len(deferred.result) + else: + return 1 + + class Cache(object): __slots__ = ( "cache", @@ -53,10 +60,11 @@ class Cache(object): "metrics", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=deferred_size if iterable else None, ) self.name = name @@ -155,7 +163,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +177,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +213,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -421,17 +432,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, 
num_args=1, tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +453,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c679175..00ddf38290 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -58,6 +58,18 @@ class LruCache(object): lock = threading.Lock() + def cache_len(): + if size_callback is not None: + return sum(size_callback(node.value) for node in cache.itervalues()) + else: + return len(cache) + + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -127,22 +139,18 @@ class LruCache(object): else: callbacks = set() add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + + evict() @synchronized def cache_set_default(key, value): node = cache.get(key, None) if node is not None: + evict() # As the new node may be bigger than the old node. 
return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -175,10 +183,6 @@ class LruCache(object): cb() cache.clear() - @synchronized - def cache_len(): - return len(cache) - @synchronized def cache_contains(key): return key in cache @@ -190,7 +194,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 1eba5b535e..d888a64d0a 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -232,3 +232,28 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 1) + + +class LruCacheSizedTestCase(unittest.TestCase): + + def test_evict(self): + cache = LruCache(5, size_callback=len) + cache["key1"] = [0] + cache["key2"] = [1, 2] + cache["key3"] = [3] + cache["key4"] = [4] + + self.assertEquals(cache["key1"], [0]) + self.assertEquals(cache["key2"], [1, 2]) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(len(cache), 5) + + cache["key5"] = [5, 6] + + self.assertEquals(len(cache), 4) + self.assertEquals(cache.get("key1"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(cache["key5"], [5, 6]) -- cgit 1.4.1 From 01521299c7d6d65b0f8b567bc7b7dbf94b7a81ce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 11:56:51 +0000 Subject: Increase cache size limit --- synapse/storage/roommember.py | 2 +- synapse/storage/state.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index e63aab6ccf..8dce89073d 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -391,7 +391,7 @@ class RoomMemberStore(SQLBaseStore): ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=2000) + max_entries=50000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c480743f89..fe942ecad7 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000, iterable=True) + @cached(num_args=2, max_entries=50000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() -- cgit 1.4.1 From 46aebbbcbf94eb78ae45d3bb3bf3ffeabb44dd4f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 13:48:04 +0000 Subject: Add support for 'iterable' to ExpiringCache --- synapse/state.py | 6 +++++- synapse/util/caches/expiringcache.py | 26 +++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index b9d5627a82..461e82acdf 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", 
"state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(10000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -77,6 +77,9 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + def __len__(self): + return len(self.state) + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -99,6 +102,7 @@ class StateHandler(object): clock=self.clock, max_len=SIZE_OF_CACHE, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, reset_expiry_on_get=True, ) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 080388958f..9b44b3fab3 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): + reset_expiry_on_get=False, iterable=False): """ Args: cache_name (str): Name of this cache, used for logging. @@ -36,6 +36,8 @@ class ExpiringCache(object): evicted based on time. reset_expiry_on_get (bool): If true, will reset the expiry time for an item on access. Defaults to False. + iterable (bool): If true, the size is calculated by summing the + sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -49,7 +51,9 @@ class ExpiringCache(object): self._cache = {} - self.metrics = register_cache(cache_name, self._cache) + self.metrics = register_cache(cache_name, self) + + self.iterable = iterable def start(self): if not self._expiry_ms: @@ -66,14 +70,15 @@ class ExpiringCache(object): self._cache[key] = _CacheEntry(now, value) # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: + if self._max_len and len(self) > self._max_len: sorted_entries = sorted( - self._cache.items(), + self._cache.keys(), key=lambda item: item[1].time, ) - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) + while len(self) > self._max_len and sorted_entries: + key = sorted_entries.pop() + self._cache.pop(key) def __getitem__(self, key): try: @@ -99,7 +104,7 @@ class ExpiringCache(object): # zero expiry time means don't expire. This should never get called # since we have this check in start too. 
return - begin_length = len(self._cache) + begin_length = len(self) now = self._clock.time_msec() @@ -114,11 +119,14 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache) + self._cache_name, begin_length, len(self) ) def __len__(self): - return len(self._cache) + if self.iterable: + return sum(len(value.value) for value in self._cache.itervalues()) + else: + return len(self._cache) class _CacheEntry(object): -- cgit 1.4.1 From beda469bc6e96a0b776c3d6742cf97950819b2f0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:05:24 +0000 Subject: Put staticmethods at module level --- synapse/state.py | 244 +++++++++++++++++++++++++++---------------------------- 1 file changed, 121 insertions(+), 123 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index c75499c3e0..90b14e758c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -337,7 +337,7 @@ class StateHandler(object): for st in state_groups_ids.values() ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = Resolver.resolve_events( + new_state, _ = resolve_events( state_sets, event_type, state_key ) new_state = { @@ -392,11 +392,11 @@ class StateHandler(object): ) with Measure(self.clock, "state._resolve_events"): if event.is_state(): - return Resolver.resolve_events( + return resolve_events( state_sets, event.type, event.state_key ) else: - return Resolver.resolve_events(state_sets) + return resolve_events(state_sets) def _ordered_events(events): @@ -406,138 +406,136 @@ def _ordered_events(events): return sorted(events, key=key_func) -class Resolver(object): - @staticmethod - def resolve_events(state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } +def resolve_events(state_sets, event_type=None, state_key=""): + """ + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. 
+ """ + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e + + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } + + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } + + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] + + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + new_state = unconflicted_state + new_state.update(resolved_state) - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] + return new_state, prev_states + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state: + events = conflicted_state[power_key] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[power_key] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events.update(resolved_state) - try: - resolved_state = Resolver._resolve_state_events( - conflicted_state, auth_events + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events ) - except: - logger.exception("Failed to resolve state") - raise - new_state = unconflicted_state - new_state.update(resolved_state) + auth_events.update(resolved_state) - return new_state, prev_states + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) - @staticmethod - def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. + return resolved_state - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. 
- """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = Resolver._resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - auth_events.update(resolved_state) +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + prev_event = event + except AuthError: + return prev_event - auth_events.update(resolved_state) + return event - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = Resolver._resolve_normal_events( - events, auth_events - ) - return resolved_state - - @staticmethod - def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] - - auth_events = dict(auth_events) - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - @staticmethod - def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. 
+ return event -- cgit 1.4.1 From 897f8752da3c9f7b2d214fe91e8356be5db545c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:08:17 +0000 Subject: Up cache max entries for state --- synapse/state.py | 2 +- synapse/storage/roommember.py | 2 +- synapse/storage/state.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index 461e82acdf..66e1a685e8 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(10000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8dce89073d..768e0a4451 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -391,7 +391,7 @@ class RoomMemberStore(SQLBaseStore): ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=50000) + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index fe942ecad7..7d34dd03bf 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=50000, iterable=True) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() -- cgit 1.4.1 From 6d00213e80fa51380c8ad7b339e7420edec27f9a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:33:22 +0000 Subject: Use OrderedDict in ExpiringCache --- synapse/util/caches/expiringcache.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 9b44b3fab3..b9ead9cbd5 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,6 +15,7 @@ from synapse.util.caches import register_cache +from collections import OrderedDict import logging @@ -49,7 +50,7 @@ class ExpiringCache(object): self._reset_expiry_on_get = reset_expiry_on_get - self._cache = {} + self._cache = OrderedDict() self.metrics = register_cache(cache_name, self) @@ -70,15 +71,8 @@ class ExpiringCache(object): self._cache[key] = _CacheEntry(now, value) # Evict if there are now too many items - if self._max_len and len(self) > self._max_len: - sorted_entries = sorted( - self._cache.keys(), - key=lambda item: item[1].time, - ) - - while len(self) > self._max_len and sorted_entries: - key = sorted_entries.pop() - self._cache.pop(key) + while self._max_len and len(self) > self._max_len: + self._cache.popitem(last=False) def __getitem__(self, key): try: -- cgit 1.4.1 From f2f179dce26f42ea0e691d17c60b297c63898923 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:33:34 +0000 Subject: Add ExpiringCache tests --- tests/util/test_expiring_cache.py | 84 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/util/test_expiring_cache.py diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py new file 
mode 100644
index 0000000000..31d24adb8b
--- /dev/null
+++ b/tests/util/test_expiring_cache.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .. import unittest
+
+from synapse.util.caches.expiringcache import ExpiringCache
+
+from tests.utils import MockClock
+
+
+class ExpiringCacheTestCase(unittest.TestCase):
+
+    def test_get_set(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=1)
+
+        cache["key"] = "value"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache["key"], "value")
+
+    def test_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=2)
+
+        cache["key"] = "value"
+        cache["key2"] = "value2"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache.get("key2"), "value2")
+
+        cache["key3"] = "value3"
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), "value2")
+        self.assertEquals(cache.get("key3"), "value3")
+
+    def test_iterable_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+
+        cache["key"] = [1]
+        cache["key2"] = [2, 3]
+        cache["key3"] = [4, 5]
+
+        self.assertEquals(cache.get("key"), [1])
+        self.assertEquals(cache.get("key2"), [2, 3])
+        self.assertEquals(cache.get("key3"), [4, 5])
+
+        cache["key4"] = [6, 7]
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache.get("key3"), [4, 5])
+        self.assertEquals(cache.get("key4"), [6, 7])
+
+    def test_time_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, expiry_ms=1000)
+        cache.start()
+
+        cache["key"] = 1
+        clock.advance_time(0.5)
+        cache["key2"] = 2
+
+        self.assertEquals(cache.get("key"), 1)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(0.9)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(1)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
-- cgit 1.4.1


From f85b6ca494ae587731d99196020cc74d7eca012a Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 11:18:13 +0000
Subject: Speed up cache size calculation

Instead of calculating the size of the cache repeatedly, which can take
a long time now that it can use a callback, cache the size and update it
on insertion and deletion.

This requires changing the cache descriptors to have two caches, one for
pending deferreds and the other for the actual values. There's no reason
to evict from the pending deferreds as they won't take up any more
memory.
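The two-tier layout this message describes can be sketched as follows. Illustrative only: TwoTierCache and sized_cache are invented names, not Synapse's API; the sized cache is assumed to expose get/set in the style of LruCache, and values are Twisted Deferreds.

    class TwoTierCache(object):
        def __init__(self, sized_cache):
            self.pending = {}               # key -> Deferred; never sized or evicted
            self.sized_cache = sized_cache  # e.g. an LRU cache with a size_callback

        def get(self, key, default=None):
            # Serve in-flight lookups first, so concurrent callers share
            # a single Deferred instead of hitting the database twice.
            d = self.pending.get(key)
            if d is not None:
                return d
            return self.sized_cache.get(key, default)

        def set(self, key, deferred):
            self.pending[key] = deferred

            def promote(result):
                # Only a fired Deferred has a concrete value whose size can
                # be measured, so move the result into the sized cache once
                # it resolves and let normal eviction apply from there.
                if self.pending.get(key) is deferred:
                    del self.pending[key]
                    self.sized_cache.set(key, result)
                return result

            deferred.addCallback(promote)
            return deferred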
--- synapse/util/caches/descriptors.py | 97 +++++++++++++++++++++++++-------- synapse/util/caches/dictionary_cache.py | 6 +- synapse/util/caches/expiringcache.py | 15 ++++- synapse/util/caches/lrucache.py | 42 ++++++++------ synapse/util/caches/treecache.py | 14 ++++- tests/storage/test__base.py | 6 +- tests/util/test_lrucache.py | 30 +++++----- 7 files changed, 148 insertions(+), 62 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index d082c26b1f..b3b2d6092d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, popped_to_iterator from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,11 +42,23 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -def deferred_size(deferred): - if deferred.called: - return len(deferred.result) - else: - return 1 +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() class Cache(object): @@ -58,13 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, - size_callback=deferred_size if iterable else None, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -84,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -96,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == entry.sequence: + existing_entry = 
self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -116,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -127,6 +178,12 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + val = self._pending_deferred_cache.pop(key, None) + if val is not None: + entry_dict, _ = val + for entry in popped_to_iterator(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -254,11 +311,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -272,7 +324,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -370,7 +422,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -393,8 +444,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79d..cb6933c61c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index b9ead9cbd5..2987c38a2d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -56,6 +56,8 @@ class ExpiringCache(object): self.iterable = iterable + self._size_estimate = 0 + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire @@ -70,9 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + if self.iterable: + self._size_estimate 
+= len(value) + # Evict if there are now too many items while self._max_len and len(self) > self._max_len: - self._cache.popitem(last=False) + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -109,7 +116,9 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", @@ -118,7 +127,7 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return sum(len(value.value) for value in self._cache.itervalues()) + return self._size_estimate else: return len(self._cache) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 00ddf38290..f1de034444 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -58,12 +58,6 @@ class LruCache(object): lock = threading.Lock() - def cache_len(): - if size_callback is not None: - return sum(size_callback(node.value) for node in cache.itervalues()) - else: - return len(cache) - def evict(): while cache_len() > max_size: todelete = list_root.prev_node @@ -78,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -86,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -104,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -128,17 +137,16 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) + add_node(key, value, set(callbacks)) evict() diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea3..460e98a92d 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,24 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(popped_to_iterator(self.root)) def __len__(self): return self.size +def popped_to_iterator(d): + if isinstance(d, dict): + for value_d in d.itervalues(): + for value in 
popped_to_iterator(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ab6095564a..8361dd8cee 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=2) + @cached(max_entries=20) # HACK: This makes it 2 due to cache factor def func(self, key): callcount[0] += 1 return key @@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 2) self.assertEquals(callcount2[0], 2) + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + yield a.func("foo3") self.assertEquals(callcount[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index d888a64d0a..99aab65001 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.get("key", "value") @@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value2") @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", [m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", [m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", m1) - cache.set(("a", "2"), "value", m2) - cache.set(("b", "1"), "value", m3) - cache.set(("b", "2"), "value", m4) + cache.set(("a", "1"), "value", [m1]) + cache.set(("a", "2"), "value", [m2]) + cache.set(("b", "1"), "value", [m3]) + cache.set(("b", "2"), "value", [m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", [m1]) + cache.set("key2", "value", [m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", [m1]) + cache.set("key2", "value", [m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", m3) + cache.set("key3", "value", [m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", 
"value", m1) + cache.set("key1", "value", [m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) -- cgit 1.4.1 From d9062060499d670f41ebc31d43003bed3502a722 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:25:51 +0000 Subject: Increase state_group_cache_size --- synapse/storage/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5620a655eb..963ef999d5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() -- cgit 1.4.1 From 1ccd5676e3fe01bcc1c59fd06f400f629b24c3ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:42:26 +0000 Subject: Remove needless call to evict() --- synapse/util/caches/lrucache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f1de034444..072f9a9d19 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -154,7 +154,6 @@ class LruCache(object): def cache_set_default(key, value): node = cache.get(key, None) if node is not None: - evict() # As the new node may be bigger than the old node. return node.value else: add_node(key, value) -- cgit 1.4.1 From d6c75cb7c237a31252f0838d2aa6114cd58b2ad4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:44:57 +0000 Subject: Rename and comment tree_to_leaves_iterator --- synapse/util/caches/descriptors.py | 4 ++-- synapse/util/caches/treecache.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b3b2d6092d..a9ea97fd46 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, popped_to_iterator +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -181,7 +181,7 @@ class Cache(object): val = self._pending_deferred_cache.pop(key, None) if val is not None: entry_dict, _ = val - for entry in popped_to_iterator(entry_dict): + for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() def invalidate_all(self): diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 460e98a92d..fcc341a6b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,16 +65,19 @@ class TreeCache(object): return popped def values(self): - return list(popped_to_iterator(self.root)) + return list(iterate_tree_cache_entry(self.root)) def __len__(self): return self.size -def popped_to_iterator(d): +def iterate_tree_cache_entry(d): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. 
+ """ if isinstance(d, dict): for value_d in d.itervalues(): - for value in popped_to_iterator(value_d): + for value in iterate_tree_cache_entry(value_d): yield value else: if isinstance(d, _Entry): -- cgit 1.4.1 From 9e8e236d9816ef639bdeb72cbb4de0fc29c6b120 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:48:02 +0000 Subject: Tidy up test --- tests/util/test_lrucache.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 99aab65001..dfb78cb8bd 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", [m]) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", [m]) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", [m1]) - cache.set(("a", "2"), "value", [m2]) - cache.set(("b", "1"), "value", [m3]) - cache.set(("b", "2"), "value", [m4]) + cache.set(("a", "1"), "value", callbacks=[m1]) + cache.set(("a", "2"), "value", callbacks=[m2]) + cache.set(("b", "1"), "value", callbacks=[m3]) + cache.set(("b", "2"), "value", callbacks=[m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", [m1]) - cache.set("key2", "value", [m2]) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", [m1]) - cache.set("key2", "value", [m2]) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", [m3]) + cache.set("key3", "value", callbacks=[m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", [m1]) + cache.set("key1", "value", callbacks=[m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) -- cgit 1.4.1 From 5d6bad1b3c325897db81f84ebfc67ca687d851c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 13:16:54 +0000 Subject: Optimise state resolution --- synapse/event_auth.py | 49 ++++++++-- synapse/events/__init__.py | 8 +- synapse/events/builder.py | 6 +- synapse/handlers/federation.py | 2 +- synapse/state.py | 209 +++++++++++++++++++++++++++++------------ tests/api/test_filtering.py | 5 +- tests/events/test_utils.py | 22 ++++- 7 files changed, 229 insertions(+), 72 deletions(-) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 983d8e9a85..3b7726a526 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -27,7 +27,7 @@ from synapse.types import UserID, 
get_domain_from_id logger = logging.getLogger(__name__) -def check(event, auth_events, do_sig_check=True): +def check(event, auth_events, do_sig_check=True, do_size_check=True): """ Checks if this event is correctly authed. Args: @@ -38,7 +38,8 @@ def check(event, auth_events, do_sig_check=True): Returns: True if the auth checks pass. """ - _check_size_limits(event) + if do_size_check: + _check_size_limits(event) if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) @@ -119,10 +120,11 @@ def check(event, auth_events, do_sig_check=True): ) return True - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) if event.type == EventTypes.Member: allowed = _is_membership_change_allowed( @@ -639,3 +641,38 @@ def get_public_keys(invite_event): public_keys.append(o) public_keys.extend(invite_event.content.get("public_keys", [])) return public_keys + + +def auth_types_for_event(event): + """Given an event, return a list of (EventType, StateKey) that may be + needed to auth the event. The returned list may be a superset of what + would actually be required depending on the full state of the room. + + Used to limit the number of events to fetch from the database to + actually auth the event. + """ + if event.type == EventTypes.Create: + return [] + + auth_types = [] + + auth_types.append((EventTypes.PowerLevels, "", )) + auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Create, "", )) + + if event.type == EventTypes.Member: + e_type = event.content["membership"] + if e_type in [Membership.JOIN, Membership.INVITE]: + auth_types.append((EventTypes.JoinRules, "", )) + + auth_types.append((EventTypes.Member, event.state_key, )) + + if e_type == Membership.INVITE: + if "third_party_invite" in event.content: + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["signed"]["token"] + ) + auth_types.append(key) + + return auth_types diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index da9f3ad436..e673e96cc0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -79,7 +79,6 @@ class EventBase(object): auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") content = _event_dict_property("content") - event_id = _event_dict_property("event_id") hashes = _event_dict_property("hashes") origin = _event_dict_property("origin") origin_server_ts = _event_dict_property("origin_server_ts") @@ -88,8 +87,6 @@ class EventBase(object): redacts = _event_dict_property("redacts") room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") - state_key = _event_dict_property("state_key") - type = _event_dict_property("type") user_id = _event_dict_property("sender") @property @@ -162,6 +159,11 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict + self.event_id = event_dict["event_id"] + self.type = event_dict["type"] + if "state_key" in event_dict: + self.state_key = event_dict["state_key"] + super(FrozenEvent, self).__init__( frozen_dict, signatures=signatures, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 7369d70980..365fd96bd2 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . 
import EventBase, FrozenEvent +from . import EventBase, FrozenEvent, _event_dict_property from synapse.types import EventID @@ -34,6 +34,10 @@ class EventBuilder(EventBase): internal_metadata_dict=internal_metadata_dict, ) + event_id = _event_dict_property("event_id") + state_key = _event_dict_property("state_key") + type = _event_dict_property("type") + def build(self): return FrozenEvent.from_event(self) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1021bcc405..ea89e0cf2d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state, prev_state = self.state_handler.resolve_events( + new_state = self.state_handler.resolve_events( [local_view.values(), remote_view.values()], event ) diff --git a/synapse/state.py b/synapse/state.py index 90b14e758c..294e0c2081 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -22,11 +22,10 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer -from collections import namedtuple +from collections import namedtuple, defaultdict from frozendict import frozendict import logging @@ -48,6 +47,8 @@ EVICTION_TIMEOUT_SECONDS = 60 * 60 _NEXT_STATE_ID = 1 +POWER_KEY = (EventTypes.PowerLevels, "") + def _gen_state_id(): global _NEXT_STATE_ID @@ -328,21 +329,13 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) - state_map = yield self.store.get_events( - [e_id for st in state_groups_ids.values() for e_id in st.values()], - get_prev_content=False - ) - state_sets = [ - [state_map[e_id] for key, e_id in st.items() if e_id in state_map] - for st in state_groups_ids.values() - ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = resolve_events( - state_sets, event_type, state_key + new_state = yield resolve_events( + state_groups_ids.values(), + state_map_factory=lambda ev_ids: self.store.get_events( + ev_ids, get_prev_content=False + ), ) - new_state = { - key: e.event_id for key, e in new_state.items() - } else: new_state = { key: e_ids.pop() for key, e_ids in state.items() @@ -390,13 +383,25 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } + with Measure(self.clock, "state._resolve_events"): - if event.is_state(): - return resolve_events( - state_sets, event.type, event.state_key - ) - else: - return resolve_events(state_sets) + new_state = resolve_events(state_set_ids, state_map) + + new_state = { + key: state_map[ev_id] for key, ev_id in new_state.items() + } + + return new_state def _ordered_events(events): @@ -406,43 +411,117 @@ def _ordered_events(events): return sorted(events, key=key_func) -def resolve_events(state_sets, event_type=None, state_key=""): +def resolve_events(state_sets, state_map_factory): """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. 
+ state_map_factory(dict|callable): If callable, then will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. Otherwise, should be + a dict from event_id to event of all events in state_sets. + Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. + dict[(str, str), synapse.events.FrozenEvent] is a map from + (type, state_key) to event. """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + if callable(state_map_factory): + return _resolve_with_state_fac( + unconflicted_state, conflicted_state, state_map_factory + ) + + state_map = state_map_factory + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + """ + unconflicted_state = dict(state_sets[0]) + conflicted_state = {} + + full_states = defaultdict( + set, + {k: set((v,)) for k, v in state_sets[0].iteritems()} + ) + + for state_set in state_sets[1:]: + for key, value in state_set.iteritems(): + ls = full_states[key] + if not ls: + ls.add(value) + unconflicted_state[key] = value + elif value not in ls: + ls.add(value) + if len(ls) == 2: + conflicted_state[key] = ls + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +@defer.inlineCallbacks +def _resolve_with_state_fac(unconflicted_state, conflicted_state, + state_map_factory): + needed_events = set( + event_id + for event_ids in conflicted_state.itervalues() + for event_id in event_ids + ) + + state_map = yield state_map_factory(needed_events) + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(auth_events.itervalues()) + new_needed_events -= needed_events + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in conflicted_state.itervalues(): + for event_id in event_ids: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state, conflicted_state, auth_events, + state_map): conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 + key: [state_map[ev_id] for ev_id in event_ids] + for key, event_ids in conflicted_state.items() } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] - auth_events = { - k: e for k, e in 
unconflicted_state.items() - if k[0] in AuthEventTypes + key: state_map[ev_id] + for key, ev_id in auth_events.items() } try: @@ -454,9 +533,10 @@ def resolve_events(state_sets, event_type=None, state_key=""): raise new_state = unconflicted_state - new_state.update(resolved_state) + for key, event in resolved_state.iteritems(): + new_state[key] = event.event_id - return new_state, prev_states + return new_state def _resolve_state_events(conflicted_state, auth_events): @@ -470,11 +550,10 @@ def _resolve_state_events(conflicted_state, auth_events): 4. other events. """ resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = _resolve_auth_events( + resolved_state[POWER_KEY] = _resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -512,14 +591,26 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] - auth_events = dict(auth_events) + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events prev_event = reverse[0] for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) prev_event = event except AuthError: return prev_event @@ -531,7 +622,7 @@ def _resolve_normal_events(events, auth_events): for event in _ordered_events(events): try: # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) return event except AuthError: pass diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index dcb6c5bc31..50e8607c14 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -25,10 +25,13 @@ from synapse.api.filtering import Filter from synapse.events import FrozenEvent user_localpart = "test_user" -# MockEvent = namedtuple("MockEvent", "sender type room_id") def MockEvent(**kwargs): + if "event_id" not in kwargs: + kwargs["event_id"] = "fake_event_id" + if "type" not in kwargs: + kwargs["type"] = "fake_type" return FrozenEvent(kwargs) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 29f068d1f1..dfc870066e 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event def MockEvent(**kwargs): + if "event_id" not in kwargs: + kwargs["event_id"] = "fake_event_id" + if "type" not in kwargs: + kwargs["type"] = "fake_type" return FrozenEvent(kwargs) @@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase): def test_minimal(self): self.run_test( - {'type': 'A'}, { 'type': 'A', + 'event_id': '$test:domain', + }, + { + 'type': 'A', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 
'type': 'B', + 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}, }, { 'type': 'B', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {'age_ts': 20}, @@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 'type': 'B', + 'event_id': '$test:domain', 'unsigned': {'other_key': 'here'}, }, { 'type': 'B', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 'type': 'C', + 'event_id': '$test:domain', 'content': {'things': 'here'}, }, { 'type': 'C', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 'type': 'm.room.create', + 'event_id': '$test:domain', 'content': {'creator': '@2:domain', 'other_field': 'here'}, }, { 'type': 'm.room.create', + 'event_id': '$test:domain', 'content': {'creator': '@2:domain'}, 'signatures': {}, 'unsigned': {}, @@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase): self.assertEquals( self.serialize( MockEvent( + type="foo", + event_id="test", room_id="!foo:bar", content={ "foo": "bar", @@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase): [] ), { + "type": "foo", + "event_id": "test", "room_id": "!foo:bar", "content": { "foo": "bar", -- cgit 1.4.1 From e6153e1bd10529b28b69820decbc039b9d6a1f27 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 13:21:04 +0000 Subject: Fix couple of federation state bugs --- synapse/federation/federation_client.py | 6 ++++-- synapse/handlers/federation.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b4bcec77ed..c9175bb33d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, builder import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -499,8 +499,10 @@ class FederationClient(FederationBase): if "prev_state" not in pdu_dict: pdu_dict["prev_state"] = [] + ev = builder.EventBuilder(pdu_dict) + defer.returnValue( - (destination, self.event_from_pdu_json(pdu_dict)) + (destination, ev) ) break except CodeMessageException as e: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ea89e0cf2d..ced5646e9a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -596,7 +596,7 @@ class FederationHandler(BaseHandler): preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids ])) - states = dict(zip(event_ids, [s[1] for s in states])) + states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( [e_id for ids in states.values() for e_id in ids], -- cgit 1.4.1 From 633f97151c6c7fa693b3de4addad641186b4ef02 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 13:33:54 +0000 Subject: Check event is in state_map --- synapse/state.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index 294e0c2081..df9b6b3ccb 100644 
--- a/synapse/state.py +++ b/synapse/state.py @@ -333,7 +333,7 @@ class StateHandler(object): new_state = yield resolve_events( state_groups_ids.values(), state_map_factory=lambda ev_ids: self.store.get_events( - ev_ids, get_prev_content=False + ev_ids, get_prev_content=False, check_redacted=False, ), ) else: @@ -482,6 +482,8 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state, for event_id in event_ids ) + logger.info("Asking for %d conflicted events", len(needed_events)) + state_map = yield state_map_factory(needed_events) auth_events = _create_auth_events_from_maps( @@ -491,6 +493,8 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state, new_needed_events = set(auth_events.itervalues()) new_needed_events -= needed_events + logger.info("Asking for %d auth events", len(new_needed_events)) + state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) @@ -515,13 +519,14 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma def _resolve_with_state(unconflicted_state, conflicted_state, auth_events, state_map): conflicted_state = { - key: [state_map[ev_id] for ev_id in event_ids] + key: [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] for key, event_ids in conflicted_state.items() } auth_events = { key: state_map[ev_id] for key, ev_id in auth_events.items() + if ev_id in state_map } try: -- cgit 1.4.1 From ce59a2faad253409a8047ce9302d3d6c087fe812 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 14:18:53 +0000 Subject: Correctly handle case of rejected events in state res --- synapse/state.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index df9b6b3ccb..d2bd1ad646 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -507,21 +507,27 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma auth_events = {} for event_ids in conflicted_state.itervalues(): for event_id in event_ids: - keys = event_auth.auth_types_for_event(state_map[event_id]) - for key in keys: - if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id + if event_id in state_map: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id return auth_events def _resolve_with_state(unconflicted_state, conflicted_state, auth_events, state_map): - conflicted_state = { - key: [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] - for key, event_ids in conflicted_state.items() - } + new_conflicted_state = {} + for key, event_ids in conflicted_state.iteritems(): + events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] + if len(events) > 1: + new_conflicted_state[key] = events + elif len(events) == 1: + unconflicted_state[key] = events[0].event_id + + conflicted_state = new_conflicted_state auth_events = { key: state_map[ev_id] -- cgit 1.4.1 From 04006bb7f014fa62c1534fac7250e7b845fa91d3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 14:31:21 +0000 Subject: Get state at event rather than for room in push --- synapse/push/push_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index b47bf1f92b..a27476bbad 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py 
@@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
 def get_context_for_event(store, state_handler, ev, user_id):
     ctx = {}
 
-    room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
+    room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
 
     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or
-- cgit 1.4.1


From e5d2df9c3452617e3390b2c356e11b7c49b022b1 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:32:53 +0000
Subject: Use better variable name

---
 synapse/event_auth.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 3b7726a526..4096c606f1 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -661,13 +661,13 @@ def auth_types_for_event(event):
         auth_types.append((EventTypes.Create, "", ))
 
     if event.type == EventTypes.Member:
-        e_type = event.content["membership"]
-        if e_type in [Membership.JOIN, Membership.INVITE]:
+        membership = event.content["membership"]
+        if membership in [Membership.JOIN, Membership.INVITE]:
             auth_types.append((EventTypes.JoinRules, "", ))
 
         auth_types.append((EventTypes.Member, event.state_key, ))
 
-        if e_type == Membership.INVITE:
+        if membership == Membership.INVITE:
             if "third_party_invite" in event.content:
                 key = (
                     EventTypes.ThirdPartyInvite,
-- cgit 1.4.1


From 37b4c7d8a94203f790c0db408c114ec0004a2cc8 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:43:32 +0000
Subject: Fix typo in return type

---
 synapse/util/caches/descriptors.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a9ea97fd46..675bfd5feb 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -178,9 +178,8 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
-        val = self._pending_deferred_cache.pop(key, None)
-        if val is not None:
-            entry_dict, _ = val
+        entry_dict = self._pending_deferred_cache.pop(key, None)
+        if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()
-- cgit 1.4.1


From a8594fd19f48a179b263d58ba1f9c5ab2f4cb8d3 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:59:03 +0000
Subject: Use better names

---
 synapse/state.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/synapse/state.py b/synapse/state.py
index d2bd1ad646..6f62876f8c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -517,21 +517,19 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
     return auth_events
 
 
-def _resolve_with_state(unconflicted_state, conflicted_state, auth_events,
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
                         state_map):
-    new_conflicted_state = {}
-    for key, event_ids in conflicted_state.iteritems():
+    conflicted_state = {}
+    for key, event_ids in conflicted_state_ids.iteritems():
         events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
         if len(events) > 1:
-            new_conflicted_state[key] = events
+            conflicted_state[key] = events
         elif len(events) == 1:
-            unconflicted_state[key] = events[0].event_id
-
-    conflicted_state = new_conflicted_state
+            unconflicted_state_ids[key] = events[0].event_id
 
     auth_events = {
         key: state_map[ev_id]
-        for key, ev_id in auth_events.items()
+        for key, ev_id in auth_event_ids.items()
         if ev_id in state_map
     }
 
@@ -543,7 +541,7 @@
def _resolve_with_state(unconflicted_state, conflicted_state, auth_events, logger.exception("Failed to resolve state") raise - new_state = unconflicted_state + new_state = unconflicted_state_ids for key, event in resolved_state.iteritems(): new_state[key] = event.event_id -- cgit 1.4.1 From c6064a7ba6bae6055dc7960e5eef3956131b718d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 17 Jan 2017 15:23:07 +0000 Subject: Only construct sets when necessary --- synapse/state.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index 6f62876f8c..81c6bae737 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -453,22 +453,27 @@ def _seperate(state_sets): unconflicted_state = dict(state_sets[0]) conflicted_state = {} - full_states = defaultdict( - set, - {k: set((v,)) for k, v in state_sets[0].iteritems()} - ) - for state_set in state_sets[1:]: for key, value in state_set.iteritems(): - ls = full_states[key] - if not ls: - ls.add(value) - unconflicted_state[key] = value - elif value not in ls: - ls.add(value) - if len(ls) == 2: - conflicted_state[key] = ls - unconflicted_state.pop(key, None) + # Check if there is an unconflicted entry for the state key. + unconflicted_value = unconflicted_state.get(key) + if unconflicted_value is None: + # There isn't an unconflicted entry so check if there is a + # conflicted entry. + ls = conflicted_state.get(key) + if ls is None: + # There wasn't a conflicted entry so haven't seen this key before. + # Therefore it isn't conflicted yet. + unconflicted_state[key] = value + else: + # This key is already conflicted, add our value to the conflict set. + ls.add(value) + elif unconflicted_value != value: + # If the unconflicted value is not the same as our value then we + # have a new conflict. So move the key from the unconflicted_state + # to the conflicted state. 
+                conflicted_state[key] = {value, unconflicted_value}
+                unconflicted_state.pop(key, None)
 
     return unconflicted_state, conflicted_state
-- cgit 1.4.1


From ed4d1761525b21989279b99733e415c1c86ed39f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 15:27:28 +0000
Subject: PEP8

---
 synapse/state.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/state.py b/synapse/state.py
index 81c6bae737..15238cd00c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
 from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
 
-from collections import namedtuple, defaultdict
+from collections import namedtuple
 from frozendict import frozendict
 
 import logging
-- cgit 1.4.1


From 380dba1020294b2c43ffb433b86917d0ee6cf9c0 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 17:04:46 +0000
Subject: Measure metrics of string_cache

---
 synapse/util/caches/__init__.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ebd715c5dc..8a7774a88e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -40,8 +40,8 @@ def register_cache(name, cache):
     )
 
 
-_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
-caches_by_name["string_cache"] = _string_cache
+_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
+_string_cache_metrics = register_cache("string_cache", _string_cache)
 
 
 KNOWN_KEYS = {
@@ -69,7 +69,12 @@ KNOWN_KEYS = {
 def intern_string(string):
     """Takes a (potentially) unicode string and interns using custom cache
     """
-    return _string_cache.setdefault(string, string)
+    new_str = _string_cache.setdefault(string, string)
+    if new_str is string:
+        _string_cache_metrics.inc_hits()
+    else:
+        _string_cache_metrics.inc_misses()
+    return new_str
 
 
 def intern_dict(dictionary):
-- cgit 1.4.1


From 5f027d1fc54ab51b420e3deb25d83ac05676fdbf Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 17:07:15 +0000
Subject: Change resolve_state_groups call site logging to DEBUG

---
 synapse/api/auth.py            | 2 +-
 synapse/handlers/federation.py | 2 +-
 synapse/state.py               | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 280d4c4452..03a215ab1b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -146,7 +146,7 @@ class Auth(object):
         with Measure(self.clock, "check_host_in_room"):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-            logger.info("calling resolve_state_groups from check_host_in_room")
+            logger.debug("calling resolve_state_groups from check_host_in_room")
             entry = yield self.state.resolve_state_groups(
                 room_id, latest_event_ids
             )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1021bcc405..2e310fed71 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -591,7 +591,7 @@ class FederationHandler(BaseHandler):
 
         event_ids = list(extremities.keys())
 
-        logger.info("calling resolve_state_groups in _maybe_backfill")
+        logger.debug("calling resolve_state_groups in _maybe_backfill")
         states = yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
diff --git a/synapse/state.py b/synapse/state.py
index 5028b0ac49..1c6e31ac58 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -128,7 +128,7 @@ class
StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_state") + logger.debug("calling resolve_state_groups from get_current_state") ret = yield self.resolve_state_groups(room_id, latest_event_ids) state = ret.state @@ -153,7 +153,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_state_ids") + logger.debug("calling resolve_state_groups from get_current_state_ids") ret = yield self.resolve_state_groups(room_id, latest_event_ids) state = ret.state @@ -167,7 +167,7 @@ class StateHandler(object): def get_current_user_in_room(self, room_id, latest_event_ids=None): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.info("calling resolve_state_groups from get_current_user_in_room") + logger.debug("calling resolve_state_groups from get_current_user_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state( room_id, entry.state_id, entry.state @@ -231,7 +231,7 @@ class StateHandler(object): context.prev_state_events = [] defer.returnValue(context) - logger.info("calling resolve_state_groups from compute_event_context") + logger.debug("calling resolve_state_groups from compute_event_context") if event.is_state(): entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], -- cgit 1.4.1 From f878f64f4314dae6bd68b11ad1edbf0883f9bd8f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 17:20:39 +0000 Subject: Lower the not retrying host log line to debug --- synapse/federation/transaction_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7db7b806dc..6b3a7abb9e 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -362,7 +362,7 @@ class TransactionQueue(object): if not success: break except NotRetryingDestination: - logger.info( + logger.debug( "TX [%s] not ready for retry yet - " "dropping transaction for now", destination, -- cgit 1.4.1 From 4ec1cf49e20f35bad2d54575fad23c8e21f8d66f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 17:18:13 +0000 Subject: Lower loading events log to DEBUG --- synapse/storage/events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 04dbdac3f8..ca501932f3 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1084,10 +1084,10 @@ class EventsStore(SQLBaseStore): self._do_fetch ) - logger.info("Loading %d events", len(events)) + logger.debug("Loading %d events", len(events)) with PreserveLoggingContext(): rows = yield events_d - logger.info("Loaded %d events (%d rows)", len(events), len(rows)) + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] -- cgit 1.4.1 From 8c5009b6282b10b2248f080cd9021a799aad5285 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 18 Jan 2017 13:25:56 +0000 Subject: Lowercase all email addresses before querying db Since we store all emails in the DB in lowercase (https://github.com/matrix-org/synapse/pull/1170) --- 
 synapse/rest/client/v1/login.py         | 8 +++++++-
 synapse/rest/client/v2_alpha/account.py | 5 +++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 093bc072f4..0c9cdff3b8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
         if 'medium' in login_submission and 'address' in login_submission:
+            address = login_submission['address']
+            if login_submission['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], login_submission['address']
+                login_submission['medium'], address
             )
             if not user_id:
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e74e5e0123..398e7f5eb0 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
             threepid = result[LoginType.EMAIL_IDENTITY]
             if 'medium' not in threepid or 'address' not in threepid:
                 raise SynapseError(500, "Malformed threepid")
+            if threepid['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                threepid['address'] = threepid['address'].lower()
             # if using email, we must know about the email they're authing with!
             threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                 threepid['medium'], threepid['address']
-- cgit 1.4.1


From c430111d0e6efb6a0f929cc3e10f1ce4f32d2c18 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 18 Jan 2017 14:55:23 +0000
Subject: Update LruCache size estimate on clear

---
 synapse/util/caches/lrucache.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 072f9a9d19..cf5fbb679c 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -189,6 +189,8 @@ class LruCache(object):
                 for cb in node.callbacks:
                     cb()
             cache.clear()
+            if size_callback:
+                cached_cache_len[0] = 0
 
     @synchronized
     def cache_contains(key):
-- cgit 1.4.1


From 1e38be3a7aaea1b6570b27e271855ee380a9129b Mon Sep 17 00:00:00 2001
From: Marvin Steadfast
Date: Thu, 19 Jan 2017 14:08:20 +0100
Subject: Added username and password for turn server

It makes it possible to use a TURN server that needs a username and
password instead of a token.
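
As a rough standalone sketch of the credential logic this change gives the
servlet (the function name and argument handling here are illustrative, not
part of the patch): ephemeral credentials follow the TURN REST API
convention, where the username encodes an expiry time and the password is
the padded-base64 HMAC-SHA1 of that username under the shared secret, while
static credentials are simply passed through.

    import base64
    import hashlib
    import hmac
    import time

    def turn_credentials(user_id, shared_secret=None,
                         static_username=None, static_password=None,
                         lifetime_secs=86400):
        # Hypothetical helper mirroring the servlet's two modes.
        if shared_secret:
            # The username carries the expiry; the TURN server recomputes
            # the same HMAC, so standard padded base64 must be used.
            expiry = int(time.time()) + lifetime_secs
            username = "%d:%s" % (expiry, user_id)
            mac = hmac.new(shared_secret, msg=username, digestmod=hashlib.sha1)
            return username, base64.b64encode(mac.digest())
        if static_username and static_password:
            # Fixed credentials for TURN servers that do not use tokens.
            return static_username, static_password
        return None, None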
--- synapse/config/voip.py | 4 +++- synapse/rest/client/v1/voip.py | 26 +++++++++++++++++--------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 169980f60d..ef9d61adfc 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -19,7 +19,9 @@ class VoipConfig(Config): def read_config(self, config): self.turn_uris = config.get("turn_uris", []) - self.turn_shared_secret = config["turn_shared_secret"] + self.turn_shared_secret = config.get("turn_shared_secret") + self.turn_username = config.get("turn_username") + self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) def default_config(self, **kwargs): diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index c40442f958..03141c623c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -32,18 +32,26 @@ class VoipRestServlet(ClientV1RestServlet): turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret + turnUsername = self.hs.config.turn_username + turnPassword = self.hs.config.turn_password userLifetime = self.hs.config.turn_user_lifetime - if not turnUris or not turnSecret or not userLifetime: - defer.returnValue((200, {})) - expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 - username = "%d:%s" % (expiry, requester.user.to_string()) + if turnUris and turnSecret and userLifetime: + expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 + username = "%d:%s" % (expiry, requester.user.to_string()) + + mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) + # We need to use standard padded base64 encoding here + # encode_base64 because we need to add the standard padding to get the + # same result as the TURN server. + password = base64.b64encode(mac.digest()) - mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) - # We need to use standard padded base64 encoding here - # encode_base64 because we need to add the standard padding to get the - # same result as the TURN server. 
- password = base64.b64encode(mac.digest()) + elif turnUris and turnUsername and turnPassword and userLifetime: + username = turnUsername + password = turnPassword + + else: + defer.returnValue((200, {})) defer.returnValue((200, { 'username': username, -- cgit 1.4.1 From 86e616568793f4b208137ed61add2d5aba9d6c43 Mon Sep 17 00:00:00 2001 From: Marvin Steadfast Date: Thu, 19 Jan 2017 14:35:55 +0100 Subject: Added default config for turn username and password --- synapse/config/voip.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/synapse/config/voip.py b/synapse/config/voip.py index ef9d61adfc..eeb693027b 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -34,6 +34,11 @@ class VoipConfig(Config): # The shared secret used to compute passwords for the TURN server turn_shared_secret: "YOUR_SHARED_SECRET" + # The Username and password if the TURN server needs them and + # does not use a token + #turn_username: "TURNSERVER_USERNAME" + #turn_password: "TURNSERVER_PASSWORD" + # How long generated TURN credentials last turn_user_lifetime: "1h" """ -- cgit 1.4.1 From 97efe99ae964e8f4e866d961282257e6f4293fd8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 11:45:29 +0000 Subject: Make worker listener config backwards compat --- synapse/config/workers.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 904789d155..b165c67ee7 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -29,3 +29,13 @@ class WorkerConfig(Config): self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") self.worker_replication_url = config.get("worker_replication_url") + + if self.worker_listeners: + for listener in self.worker_listeners: + bind_address = listener.pop("bind_address", None) + bind_addresses = listener.setdefault("bind_addresses", []) + + if bind_address: + bind_addresses.append(bind_address) + elif not bind_addresses: + bind_addresses.append('') -- cgit 1.4.1 From 09eb08f910bd4a6077cca6ab4c3068eee55d59f3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 11:52:51 +0000 Subject: Derive current_state_events from state groups --- synapse/handlers/federation.py | 1 - synapse/state.py | 3 + synapse/storage/events.py | 188 ++++++++++++++++--------- tests/replication/slave/storage/test_events.py | 45 +++--- 4 files changed, 138 insertions(+), 99 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3f5892376..996bfd0e23 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - current_state=state, ) defer.returnValue((event_stream_id, max_stream_id)) diff --git a/synapse/state.py b/synapse/state.py index 20aaacf40f..383d32b163 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -429,6 +429,9 @@ def resolve_events(state_sets, state_map_factory): dict[(str, str), synapse.events.FrozenEvent] is a map from (type, state_key) to event. 
""" + if len(state_sets) == 1: + return state_sets[0] + unconflicted_state, conflicted_state = _seperate( state_sets, ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ca501932f3..0d6519f30d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore from twisted.internet import defer, reactor @@ -27,6 +27,7 @@ from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.state import resolve_events from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -71,22 +72,19 @@ class _EventPeristenceQueue(object): """ _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( - "events_and_contexts", "current_state", "backfilled", "deferred", + "events_and_contexts", "backfilled", "deferred", )) def __init__(self): self._event_persist_queues = {} self._currently_persisting_rooms = set() - def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): + def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: end_item = queue[-1] - if end_item.current_state or current_state: - # We perist events with current_state set to True one at a time - pass if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() @@ -96,7 +94,6 @@ class _EventPeristenceQueue(object): queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, backfilled=backfilled, - current_state=current_state, deferred=deferred, )) @@ -216,7 +213,6 @@ class EventsStore(SQLBaseStore): d = preserve_fn(self._event_persist_queue.add_to_queue)( room_id, evs_ctxs, backfilled=backfilled, - current_state=None, ) deferreds.append(d) @@ -229,11 +225,10 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function - def persist_event(self, event, context, current_state=None, backfilled=False): + def persist_event(self, event, context, backfilled=False): deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, - current_state=current_state, ) self._maybe_start_persisting(event.room_id) @@ -246,21 +241,10 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - if item.current_state: - for event, context in item.events_and_contexts: - # There should only ever be one item in - # events_and_contexts when current_state is - # not None - yield self._persist_event( - event, context, - current_state=item.current_state, - backfilled=item.backfilled, - ) - else: - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -294,36 +278,89 @@ class EventsStore(SQLBaseStore): for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the 
same event. :( + + current_state_for_room = {} + if not backfilled: + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in events_by_room.items(): + # Work out new extremities by recursively adding and removing + # the new events. + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = set(latest_event_ids) + for event, ctx in ev_ctx_rm: + if event.internal_metadata.is_outlier(): + continue + + new_latest_event_ids.difference_update( + e_id for e_id, _ in event.prev_events + ) + new_latest_event_ids.add(event.event_id) + + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue + + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in ev_ctx_rm: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) + + state_sets.extend(group_to_state.values()) + + if not new_latest_event_ids or was_updated: + current_state_for_room[room_id] = yield resolve_events( + state_sets, + state_map_factory=lambda ev_ids: self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) + yield self.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, + current_state_for_room=current_state_for_room, ) persist_event_counter.inc_by(len(chunk)) - @_retry_on_integrity_error - @defer.inlineCallbacks - @log_function - def _persist_event(self, event, context, current_state=None, backfilled=False, - delete_existing=False): - try: - with self._stream_id_gen.get_next() as stream_ordering: - event.internal_metadata.stream_ordering = stream_ordering - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() - except _RollbackButIsFineException: - pass - @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -426,7 +463,7 @@ class EventsStore(SQLBaseStore): @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False): + delete_existing=False, current_state_for_room={}): """Insert some number of room events into the necessary database tables. 
Rejected events are only inserted into the events table, the events_json table, @@ -436,6 +473,40 @@ class EventsStore(SQLBaseStore): If delete_existing is True then existing events will be purged from the database before insertion. This is useful when retrying due to IntegrityError. """ + for room_id, current_state in current_state_for_room.iteritems(): + txn.call_after(self._get_current_state_for_key.invalidate_all) + txn.call_after(self.get_rooms_for_user.invalidate_all) + txn.call_after(self.get_users_in_room.invalidate, (room_id,)) + + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": stream_order} + ) + + self._simple_delete_txn( + txn, + table="current_state_events", + keyvalues={"room_id": room_id}, + ) + + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in current_state.iteritems() + ], + ) + # Ensure that we don't have the same event twice. # Pick the earliest non-outlier if there is one, else the earliest one. new_events_and_contexts = OrderedDict() @@ -798,29 +869,6 @@ class EventsStore(SQLBaseStore): # to update the current state table return - for event, _ in state_events_and_contexts: - if event.internal_metadata.is_outlier(): - # Outlier events shouldn't clobber the current state. - continue - - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) - - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) - return def _add_to_cache(self, txn, events_and_contexts): diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 44e859b5d1..38fedfe690 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -60,7 +60,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_room_members(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.replicate() yield self.check("get_rooms_for_user", (USER_ID,), []) yield self.check("get_users_in_room", (ROOM_ID,), []) @@ -95,15 +95,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): )]) yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) - # Join the room clobbering the state. - # This should remove any evidence of the other user being in the room. yield self.persist( type="m.room.member", key=USER_ID, membership="join", - reset_state=[create] ) yield self.replicate() - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) - yield self.check("get_rooms_for_user", (USER_ID_2,), []) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID]) @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): @@ -125,7 +121,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_get_current_state(self): # Create the room. 
- create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.replicate() yield self.check( "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] @@ -151,22 +147,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): [join2] ) - # Leave the room, then rejoin the room clobbering state. - yield self.persist(type="m.room.member", key=USER_ID, membership="leave") - join3 = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - reset_state=[create] - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), - [] - ) - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), - [join3] - ) - @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID) @@ -283,6 +263,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if depth is None: depth = self.event_id + if not prev_events: + latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( + room_id + ) + prev_events = [(ev_id, {}) for ev_id in latest_event_ids] + event_dict = { "sender": sender, "type": type, @@ -309,12 +295,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_ids = { key: e.event_id for key, e in state.items() } + context = EventContext() + context.current_state_ids = state_ids + context.prev_state_ids = state_ids + elif not backfill: + state_handler = self.hs.get_state_handler() + context = yield state_handler.compute_event_context(event) else: - state_ids = None + context = EventContext() - context = EventContext() - context.current_state_ids = state_ids - context.prev_state_ids = state_ids context.push_actions = push_actions ordering = None @@ -324,7 +313,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: ordering, _ = yield self.master_store.persist_event( - event, context, current_state=reset_state + event, context, ) if ordering: -- cgit 1.4.1 From 83333498a565248bf33ce602b5687a7595ce47b1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 Jan 2017 12:15:50 +0000 Subject: fix doc for purge_media_cache purge_media_cache takes its arg from a query-param, not the POST body, for some reason. --- docs/admin_api/purge_remote_media.rst | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst index b26c6a9e7b..5deb02a3df 100644 --- a/docs/admin_api/purge_remote_media.rst +++ b/docs/admin_api/purge_remote_media.rst @@ -2,15 +2,13 @@ Purge Remote Media API ====================== The purge remote media API allows server admins to purge old cached remote -media. +media. The API is:: - POST /_matrix/client/r0/admin/purge_media_cache + POST /_matrix/client/r0/admin/purge_media_cache?before_ts=&access_token= - { - "before_ts": - } + {} Which will remove all cached media that was last accessed before ````. 
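
For illustration, calling the endpoint as the corrected docs describe
(parameters in the query string, an empty JSON body) might look like the
following sketch; the homeserver URL, access token and timestamp are
placeholders, not values from this repository:

    import requests

    resp = requests.post(
        "https://matrix.example.com/_matrix/client/r0/admin/purge_media_cache",
        params={"before_ts": 1484870400000, "access_token": "<admin_token>"},
        json={},  # the body is now just an empty JSON object
    )
    resp.raise_for_status()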
-- cgit 1.4.1 From 4c6a31cd6efa25be4c9f1b357e8f92065fac63eb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 14:28:53 +0000 Subject: Calculate the forward extremeties once --- synapse/storage/event_federation.py | 76 ++----------------- synapse/storage/events.py | 142 ++++++++++++++++++++++-------------- 2 files changed, 92 insertions(+), 126 deletions(-) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 53feaa1960..f0aa2193fb 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore): ], ) - self._update_extremeties(txn, events) + self._update_backward_extremeties(txn, events) - def _update_extremeties(self, txn, events): - """Updates the event_*_extremities tables based on the new/updated + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated events being persisted. This is called for new events *and* for events that were outliers, but - are are now being persisted as non-outliers. + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. """ events_by_room = {} for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) - for room_id, room_events in events_by_room.items(): - prevs = [ - e_id for ev in room_events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ] - if prevs: - txn.execute( - "DELETE FROM event_forward_extremities" - " WHERE room_id = ?" - " AND event_id in (%s)" % ( - ",".join(["?"] * len(prevs)), - ), - [room_id] + prevs, - ) - - query = ( - "INSERT INTO event_forward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_edges WHERE prev_event_id = ?" - " )" - ) - - txn.executemany( - query, - [ - (ev.event_id, ev.room_id, ev.event_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - - # We now insert into stream_ordering_to_exterm a mapping from room_id, - # new stream_ordering to new forward extremeties in the room. - # This allows us to later efficiently look up the forward extremeties - # for a room before a given stream_ordering - max_stream_ord = max( - ev.internal_metadata.stream_ordering for ev in events - ) - new_extrem = {} - for room_id in events_by_room: - event_ids = self._simple_select_onecol_txn( - txn, - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - new_extrem[room_id] = event_ids - - self._simple_insert_many_txn( - txn, - table="stream_ordering_to_exterm", - values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_ord, - } - for room_id, extrem_evs in new_extrem.items() - for event_id in extrem_evs - ] - ) - query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" @@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore): ] ) - for room_id in events_by_room: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): # We want to make the cache more effective, so we clamp to the last # change before the given ordering. 
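
Ignoring the SQL batching, the backward-extremity bookkeeping retained above
boils down to set manipulation; a minimal sketch over in-memory sets and
synapse-style event objects (the names here are illustrative):

    def update_backward_extremities(persisted_ids, backward_extremities, events):
        # A prev_event we have not persisted yet becomes a backward
        # extremity; an event now persisted as a non-outlier stops being one.
        non_outliers = [
            ev for ev in events if not ev.internal_metadata.is_outlier()
        ]
        for ev in non_outliers:
            for prev_id, _ in ev.prev_events:
                if prev_id not in persisted_ids:
                    backward_extremities.add(prev_id)
        for ev in non_outliers:
            persisted_ids.add(ev.event_id)
            backward_extremities.discard(ev.event_id)
        return backward_extremities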
diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0d6519f30d..295f2522b5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -279,6 +279,7 @@ class EventsStore(SQLBaseStore): # We can't easily parallelize these since different chunks # might contain the same event. :( + new_forward_extremeties = {} current_state_for_room = {} if not backfilled: # Work out the new "current state" for each room. @@ -296,20 +297,16 @@ class EventsStore(SQLBaseStore): latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = set(latest_event_ids) - for event, ctx in ev_ctx_rm: - if event.internal_metadata.is_outlier(): - continue - - new_latest_event_ids.difference_update( - e_id for e_id, _ in event.prev_events - ) - new_latest_event_ids.add(event.event_id) + new_latest_event_ids = yield self._calculate_new_extremeties( + room_id, [ev for ev, _ in ev_ctx_rm] + ) if new_latest_event_ids == set(latest_event_ids): # No change in extremities, so no change in state continue + new_forward_extremeties[room_id] = new_latest_event_ids + # Now we need to work out the different state sets for # each state extremities state_sets = [] @@ -358,9 +355,45 @@ class EventsStore(SQLBaseStore): backfilled=backfilled, delete_existing=delete_existing, current_state_for_room=current_state_for_room, + new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) + @defer.inlineCallbacks + def _calculate_new_extremeties(self, room_id, events): + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = set(latest_event_ids) + new_latest_event_ids.update( + event.event_id for event in events + if not event.internal_metadata.is_outlier() + ) + new_latest_event_ids.difference_update( + e_id + for event in events + for e_id, _ in event.prev_events + if not event.internal_metadata.is_outlier() + ) + + rows = yield self._simple_select_many_batch( + table="event_edges", + column="prev_event_id", + iterable=list(new_latest_event_ids), + retcols=["prev_event_id"], + keyvalues={ + "room_id": room_id, + "is_state": False, + }, + desc="_calculate_new_extremeties", + ) + + new_latest_event_ids.difference_update( + row["prev_event_id"] for row in rows + ) + + defer.returnValue(new_latest_event_ids) + @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -417,53 +450,10 @@ class EventsStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) - @log_function - def _persist_event_txn(self, txn, event, context, current_state, backfilled=False, - delete_existing=False): - # We purposefully do this first since if we include a `current_state` - # key, we *want* to update the `current_state_events` table - if current_state: - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - stream_order = event.internal_metadata.stream_ordering - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": stream_order} - ) - - self._simple_delete_txn( - txn, - table="current_state_events", - keyvalues={"room_id": event.room_id}, - ) - - for s in current_state: - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": s.event_id, - 
"room_id": s.room_id, - "type": s.type, - "state_key": s.state_key, - } - ) - - return self._persist_events_txn( - txn, - [(event, context)], - backfilled=backfilled, - delete_existing=delete_existing, - ) - @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, current_state_for_room={}): + delete_existing=False, current_state_for_room={}, + new_forward_extremeties={}): """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -473,6 +463,7 @@ class EventsStore(SQLBaseStore): If delete_existing is True then existing events will be purged from the database before insertion. This is useful when retrying due to IntegrityError. """ + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering for room_id, current_state in current_state_for_room.iteritems(): txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) @@ -480,11 +471,10 @@ class EventsStore(SQLBaseStore): # Add an entry to the current_state_resets table to record the point # where we clobbered the current state - stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering self._simple_insert_txn( txn, table="current_state_resets", - values={"event_stream_ordering": stream_order} + values={"event_stream_ordering": max_stream_order} ) self._simple_delete_txn( @@ -507,6 +497,46 @@ class EventsStore(SQLBaseStore): ], ) + for room_id, new_extrem in new_forward_extremeties.items(): + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + ) + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self._simple_insert_many_txn( + txn, + table="event_forward_extremities", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + } + for room_id, new_extrem in new_forward_extremeties.items() + for ev_id in new_extrem + ], + ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + self._simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_order, + } + for room_id, new_extrem in new_forward_extremeties.items() + for event_id in new_extrem + ] + ) + # Ensure that we don't have the same event twice. # Pick the earliest non-outlier if there is one, else the earliest one. new_events_and_contexts = OrderedDict() -- cgit 1.4.1 From f2f40e64a9e6965c44bad5bd2222644c2ab0a868 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 14:38:13 +0000 Subject: Comments --- synapse/storage/events.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 295f2522b5..48a4931889 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -279,6 +279,8 @@ class EventsStore(SQLBaseStore): # We can't easily parallelize these since different chunks # might contain the same event. :( + # NB: Assumes that we are only persisting events for one room + # at a time. 
new_forward_extremeties = {} current_state_for_room = {} if not backfilled: @@ -361,14 +363,21 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def _calculate_new_extremeties(self, room_id, events): + """Caculates the new forward extremeties for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) new_latest_event_ids = set(latest_event_ids) + # First, add all the new events to the list new_latest_event_ids.update( event.event_id for event in events if not event.internal_metadata.is_outlier() ) + # Now remove all events that are referenced by the to-be-added events new_latest_event_ids.difference_update( e_id for event in events @@ -376,6 +385,8 @@ class EventsStore(SQLBaseStore): if not event.internal_metadata.is_outlier() ) + # And finally remove any events that are referenced by previously added + # events. rows = yield self._simple_select_many_batch( table="event_edges", column="prev_event_id", -- cgit 1.4.1 From 567aa35b67e9dbe343dbce0459836baf0388486c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 14:40:31 +0000 Subject: Update all call sites after rename --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 48a4931889..82351f38a5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -662,7 +662,7 @@ class EventsStore(SQLBaseStore): # Update the event_backward_extremities table now that this # event isn't an outlier any more. - self._update_extremeties(txn, [event]) + self._update_backward_extremeties(txn, [event]) events_and_contexts = [ ec for ec in events_and_contexts if ec[0] not in to_remove -- cgit 1.4.1 From d0897dead575b3b8c1adac2f8d18a33a4be8e793 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 15:05:11 +0000 Subject: Spelling --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 82351f38a5..6160949f32 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -363,7 +363,7 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def _calculate_new_extremeties(self, room_id, events): - """Caculates the new forward extremeties for a room given events to + """Calculates the new forward extremeties for a room given events to persist. Assumes that we are only persisting events for one room at a time. 
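
Setting the event_edges lookup aside, the extremity update performed by
_calculate_new_extremeties reduces to set arithmetic; a pure-Python sketch
under that simplification (illustrative name, no database step):

    def new_forward_extremities(latest_event_ids, events):
        # Start from the current forward extremities, add each new
        # non-outlier event, then drop anything a new event references
        # via prev_events, since it is no longer "latest".
        new_latest = set(latest_event_ids)
        new_latest.update(
            ev.event_id for ev in events
            if not ev.internal_metadata.is_outlier()
        )
        new_latest.difference_update(
            prev_id
            for ev in events
            if not ev.internal_metadata.is_outlier()
            for prev_id, _ in ev.prev_events
        )
        # The real method also removes events referenced by previously
        # persisted events, via an event_edges query.
        return new_latest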
-- cgit 1.4.1 From a55fa2047f813d639e2a0beed0c2d2738b0b639b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 15:40:04 +0000 Subject: Insert delta of current_state_events to be more efficient --- synapse/handlers/_base.py | 8 ++- synapse/replication/slave/storage/events.py | 10 ---- synapse/storage/events.py | 78 +++++++++++++++++--------- synapse/storage/state.py | 52 ----------------- tests/replication/slave/storage/test_events.py | 29 ---------- 5 files changed, 58 insertions(+), 119 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90f96209f8..e83adc8339 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -88,9 +88,13 @@ class BaseHandler(object): current_state = yield self.store.get_events( context.current_state_ids.values() ) - current_state = current_state.values() else: - current_state = yield self.store.get_current_state(event.room_id) + current_state = yield self.state_handler.get_current_state( + event.room_id + ) + + current_state = current_state.values() + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 64f18bbb3e..b3f3bf7488 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore): get_latest_event_ids_in_room = EventFederationStore.__dict__[ "get_latest_event_ids_in_room" ] - _get_current_state_for_key = StateStore.__dict__[ - "_get_current_state_for_key" - ] get_invited_rooms_for_user = RoomMemberStore.__dict__[ "get_invited_rooms_for_user" ] @@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore): ) get_event = DataStore.get_event.__func__ get_events = DataStore.get_events.__func__ - get_current_state = DataStore.get_current_state.__func__ - get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore): def invalidate_caches_for_event(self, event, backfilled, reset_state): if reset_state: - self._get_current_state_for_key.invalidate_all() self.get_rooms_for_user.invalidate_all() self.get_users_in_room.invalidate((event.room_id,)) @@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore): if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): return - - self._get_current_state_for_key.invalidate(( - event.room_id, event.type, event.state_key - )) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6160949f32..9f57760ab0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -476,37 +476,63 @@ class EventsStore(SQLBaseStore): """ max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering for room_id, current_state in current_state_for_room.iteritems(): - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (room_id,)) - - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": max_stream_order} - ) - - self._simple_delete_txn( + existing_state_rows = self._simple_select_list_txn( txn, 
table="current_state_events", keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], ) - self._simple_insert_many_txn( - txn, - table="current_state_events", - values=[ - { - "event_id": ev_id, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - } - for key, ev_id in current_state.iteritems() - ], - ) + existing_events = set(row["event_id"] for row in existing_state_rows) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + if changed_events: + txn.executemany( + "DELETE FROM current_state_events WHERE event_id = ?", + [(ev_id,) for ev_id in changed_events], + ) + + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": max_stream_order} + ) + + events_to_insert = (new_events - existing_events) + to_insert = [ + (key, ev_id) for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + ] + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in to_insert + ], + ) + + members_changed = set( + row["state_key"] for row in existing_state_rows + if row["event_id"] in changed_events + and row["type"] == EventTypes.Member + ) + members_changed.update( + key[1] for key, event_id in to_insert + if key[0] == EventTypes.Member + ) + + for member in members_changed: + txn.call_after(self.get_rooms_for_user.invalidate, (member,)) + + txn.call_after(self.get_users_in_room.invalidate, (room_id,)) for room_id, new_extrem in new_forward_extremeties.items(): self._simple_delete_txn( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7d34dd03bf..d1d653327c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -232,58 +232,6 @@ class StateStore(SQLBaseStore): return count - @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type and state_key is not None: - result = yield self.get_current_state_for_key( - room_id, event_type, state_key - ) - defer.returnValue(result) - - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? " - ) - - if event_type and state_key is not None: - sql += " AND type = ? AND state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - txn.execute(sql, args) - results = txn.fetchall() - - return [r[0] for r in results] - - event_ids = yield self.runInteraction("get_current_state", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_current_state_for_key(self, room_id, event_type, state_key): - event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @cached(num_args=3) - def _get_current_state_for_key(self, room_id, event_type, state_key): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?" 
- ) - - args = (room_id, event_type, state_key) - txn.execute(sql, args) - results = txn.fetchall() - return [r[0] for r in results] - return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 38fedfe690..6acb8ab758 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -118,35 +118,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] ) - @defer.inlineCallbacks - def test_get_current_state(self): - # Create the room. - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] - ) - - # Join the room. - join1 = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), - [join1] - ) - - # Add some other user to the room. - join2 = yield self.persist( - type="m.room.member", key=USER_ID_2, membership="join", - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), - [join2] - ) - @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.4.1 From 5d2134d485be6c9e4e5881d099ad7dee5229e1cd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 17:13:24 +0000 Subject: Comments --- synapse/storage/events.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 9f57760ab0..2fd9f4045b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -483,6 +483,10 @@ class EventsStore(SQLBaseStore): retcols=["event_id", "type", "state_key"], ) + # Figure out what has changed (if anything). Then we simply delete + # and readd the keys that have been changed. + # This saves us from deleting and reinserting thousands of rows for + # large rooms. existing_events = set(row["event_id"] for row in existing_state_rows) new_events = set(ev_id for ev_id in current_state.itervalues()) changed_events = existing_events ^ new_events @@ -492,14 +496,6 @@ class EventsStore(SQLBaseStore): [(ev_id,) for ev_id in changed_events], ) - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": max_stream_order} - ) - events_to_insert = (new_events - existing_events) to_insert = [ (key, ev_id) for key, ev_id in current_state.iteritems() @@ -519,6 +515,13 @@ class EventsStore(SQLBaseStore): ], ) + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. 
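+            # (Concretely: if a user's m.room.member event is replaced, the
+            # old event id lands in `changed_events` and the new event id in
+            # `to_insert`, so that user is picked up by both passes below.)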
members_changed = set( row["state_key"] for row in existing_state_rows if row["event_id"] in changed_events @@ -534,6 +537,14 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_users_in_room.invalidate, (room_id,)) + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": max_stream_order} + ) + for room_id, new_extrem in new_forward_extremeties.items(): self._simple_delete_txn( txn, -- cgit 1.4.1 From c77b24c092e8ba585aa304422fe497707ca6af35 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jan 2017 14:51:33 +0000 Subject: Refactor to calculate state delta outside transaction --- synapse/storage/events.py | 205 ++++++++++++++++++++++++++-------------------- 1 file changed, 118 insertions(+), 87 deletions(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2fd9f4045b..599db4c9f0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -284,71 +284,37 @@ class EventsStore(SQLBaseStore): new_forward_extremeties = {} current_state_for_room = {} if not backfilled: - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in events_by_room.items(): - # Work out new extremities by recursively adding and removing - # the new events. - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremeties( - room_id, [ev for ev, _ in ev_ctx_rm] - ) - - if new_latest_event_ids == set(latest_event_ids): - # No change in extremities, so no change in state - continue + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) - new_forward_extremeties[room_id] = new_latest_event_ids - - # Now we need to work out the different state sets for - # each state extremities - state_sets = [] - missing_event_ids = [] - was_updated = False - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that - for ev, ctx in ev_ctx_rm: - if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - state_sets.append(ctx.current_state_ids) - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - was_updated = True - missing_event_ids.append(event_id) - - if missing_event_ids: - # Now pull out the state for any missing events from DB - event_to_groups = yield self._get_state_group_for_events( - missing_event_ids, + for room_id, ev_ctx_rm in events_by_room.items(): + # Work out new extremities by recursively adding and removing + # the new events. 
+ latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremeties( + room_id, [ev for ev, _ in ev_ctx_rm] ) - groups = set(event_to_groups.values()) - group_to_state = yield self._get_state_for_groups(groups) + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue - state_sets.extend(group_to_state.values()) + new_forward_extremeties[room_id] = new_latest_event_ids - if not new_latest_event_ids or was_updated: - current_state_for_room[room_id] = yield resolve_events( - state_sets, - state_map_factory=lambda ev_ids: self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), + state = yield self._calculate_state_delta( + room_id, ev_ctx_rm, new_latest_event_ids ) + if state: + current_state_for_room[room_id] = state yield self.runInteraction( "persist_events", @@ -405,6 +371,91 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, i.e. + (type, state_key) -> event_id. `to_delete` are the entreis to + first be deleted from current_state_events, `to_insert` are entries + to insert. + May return None if there are no changes to be applied. + """ + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in events_context: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) + + state_sets.extend(group_to_state.values()) + + if not new_latest_event_ids: + current_state = {} + elif was_updated: + current_state = yield resolve_events( + state_sets, + state_map_factory=lambda ev_ids: self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) + else: + return + + existing_state_rows = yield self._simple_select_list( + table="current_state_events", + keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], + desc="_calculate_state_delta", + ) + + existing_events = set(row["event_id"] for row in existing_state_rows) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + + if not changed_events: + return + + to_delete = { + (row["type"], row["state_key"]): row["event_id"] + for row in existing_state_rows + if row["event_id"] in changed_events + } + events_to_insert = (new_events - existing_events) + to_insert = { + key: ev_id for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + } + + 
defer.returnValue((to_delete, to_insert)) + @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -475,32 +526,13 @@ class EventsStore(SQLBaseStore): database before insertion. This is useful when retrying due to IntegrityError. """ max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - for room_id, current_state in current_state_for_room.iteritems(): - existing_state_rows = self._simple_select_list_txn( - txn, - table="current_state_events", - keyvalues={"room_id": room_id}, - retcols=["event_id", "type", "state_key"], - ) - - # Figure out what has changed (if anything). Then we simply delete - # and readd the keys that have been changed. - # This saves us from deleting and reinserting thousands of rows for - # large rooms. - existing_events = set(row["event_id"] for row in existing_state_rows) - new_events = set(ev_id for ev_id in current_state.itervalues()) - changed_events = existing_events ^ new_events - if changed_events: + for room_id, current_state_tuple in current_state_for_room.iteritems(): + to_delete, to_insert = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", - [(ev_id,) for ev_id in changed_events], + [(ev_id,) for ev_id in to_delete.itervalues()], ) - events_to_insert = (new_events - existing_events) - to_insert = [ - (key, ev_id) for key, ev_id in current_state.iteritems() - if ev_id in events_to_insert - ] self._simple_insert_many_txn( txn, table="current_state_events", @@ -511,7 +543,7 @@ class EventsStore(SQLBaseStore): "type": key[0], "state_key": key[1], } - for key, ev_id in to_insert + for key, ev_id in to_insert.iteritems() ], ) @@ -523,13 +555,12 @@ class EventsStore(SQLBaseStore): # and which we have added, then we invlidate the caches for all # those users. members_changed = set( - row["state_key"] for row in existing_state_rows - if row["event_id"] in changed_events - and row["type"] == EventTypes.Member + state_key for ev_type, state_key in to_delete.iterkeys() + if ev_type == EventTypes.Member ) members_changed.update( - key[1] for key, event_id in to_insert - if key[0] == EventTypes.Member + state_key for ev_type, state_key in to_insert.iterkeys() + if ev_type == EventTypes.Member ) for member in members_changed: -- cgit 1.4.1 From fdf2a31a51874d5a42087caa88a08d46161801f1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jan 2017 16:14:14 +0000 Subject: Typo --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 599db4c9f0..8712d7e18c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -379,7 +379,7 @@ class EventsStore(SQLBaseStore): Returns: 2-tuple (to_delete, to_insert) where both are state dicts, i.e. - (type, state_key) -> event_id. `to_delete` are the entreis to + (type, state_key) -> event_id. `to_delete` are the entries to first be deleted from current_state_events, `to_insert` are entries to insert. May return None if there are no changes to be applied. 
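
Distilled out of the transaction, the delta computation this docstring describes is a symmetric difference over event ids. A standalone sketch, assuming existing_rows mirrors the current_state_events rows and current_state is the resolved state map:

def calculate_state_delta(existing_rows, current_state):
    """Sketch of _calculate_state_delta's core.

    existing_rows: dicts with "event_id", "type" and "state_key", as read
        from current_state_events.
    current_state: dict mapping (type, state_key) -> event_id.
    Returns (to_delete, to_insert), or None if nothing changed.
    """
    existing_events = set(row["event_id"] for row in existing_rows)
    new_events = set(current_state.values())
    changed_events = existing_events ^ new_events  # symmetric difference

    if not changed_events:
        return None

    to_delete = {
        (row["type"], row["state_key"]): row["event_id"]
        for row in existing_rows
        if row["event_id"] in changed_events
    }
    to_insert = {
        key: ev_id
        for key, ev_id in current_state.items()
        if ev_id in new_events - existing_events
    }
    return to_delete, to_insert

Deleting and re-inserting only the rows in these two dicts is what saves large rooms from churning thousands of current_state_events rows on every persist.
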
-- cgit 1.4.1 From 10e48d83107492ebf872f8cd8051b95c35adc9f3 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Tue, 24 Jan 2017 18:06:07 +0000 Subject: Don't clobber a displayname or avatar_url if provided by an m.room.member event --- synapse/handlers/message.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 88bd2d572e..7a498af5a2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -208,8 +208,10 @@ class MessageHandler(BaseHandler): content = builder.content try: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + if "displayname" not in content: + content["displayname"] = yield profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = yield profile.get_avatar_url(target) except Exception as e: logger.info( "Failed to get profile information for %r: %s", -- cgit 1.4.1 From 2367c5568c01bc65aacc955b76ba707918b37f1e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Jan 2017 14:27:27 +0000 Subject: Add basic implementation of local device list changes --- synapse/federation/transaction_queue.py | 24 ++- synapse/handlers/device.py | 65 ++++++-- synapse/handlers/e2e_keys.py | 1 + synapse/handlers/sync.py | 13 ++ synapse/rest/client/v2_alpha/sync.py | 6 +- synapse/storage/__init__.py | 11 ++ synapse/storage/_base.py | 6 + synapse/storage/devices.py | 169 +++++++++++++++++++-- synapse/storage/end_to_end_keys.py | 23 ++- .../schema/delta/40/device_list_streams.sql | 56 +++++++ synapse/streams/events.py | 4 + synapse/types.py | 2 + tests/handlers/test_typing.py | 3 + tests/rest/client/v1/test_rooms.py | 4 +- 14 files changed, 348 insertions(+), 39 deletions(-) create mode 100644 synapse/storage/schema/delta/40/device_list_streams.sql diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 6b3a7abb9e..65c6673a87 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -100,6 +100,7 @@ class TransactionQueue(object): self.pending_failures_by_dest = {} self.last_device_stream_id_by_dest = {} + self.last_device_list_stream_id_by_dest = {} # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -356,7 +357,7 @@ class TransactionQueue(object): success = yield self._send_new_transaction( destination, pending_pdus, pending_edus, pending_failures, device_stream_id, - should_delete_from_device_stream=bool(device_message_edus), + includes_device_messages=bool(device_message_edus), limiter=limiter, ) if not success: @@ -373,6 +374,8 @@ class TransactionQueue(object): @defer.inlineCallbacks def _get_new_device_messages(self, destination): + # TODO: Send appropriate device list messages + last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) to_device_stream_id = self.store.get_to_device_stream_token() contents, stream_id = yield self.store.get_new_device_msgs_for_remote( @@ -387,13 +390,27 @@ class TransactionQueue(object): ) for content in contents ] + + last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0) + now_stream_id, results = yield self.store.get_devices_by_remote( + destination, last_device_list + ) + edus.extend( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.device_list_update", + content=content, + ) + for content in results + ) defer.returnValue((edus, stream_id)) 
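
For orientation, one m.device_list_update EDU is queued per changed (user, device) pair. An illustrative content dict follows; every value is invented, but the field set matches what get_devices_by_remote assembles later in this series.

example_device_list_edu = {
    "user_id": "@alice:example.com",   # hypothetical user
    "device_id": "JLAFKJWSCS",         # hypothetical device
    "stream_id": 6,                    # position in device_lists_stream
    "prev_id": [5],                    # stream ids previously sent to this host
    # only present when the device has uploaded end-to-end keys:
    "keys": {"algorithms": ["m.olm.v1.curve25519-aes-sha2"]},
    "device_display_name": "Alice's phone",
}
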
@measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, pending_failures, device_stream_id, - should_delete_from_device_stream, limiter): + includes_device_messages, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -506,7 +523,8 @@ class TransactionQueue(object): success = False else: # Remove the acknowledged device messages from the database - if should_delete_from_device_stream: + # Only bother if we actually sent some device messages + if includes_device_messages: yield self.store.delete_device_msgs_for_remote( destination, device_stream_id ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index aa68755936..d92780b642 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ from synapse.api import errors from synapse.util import stringutils +from synapse.types import get_domain_from_id from twisted.internet import defer from ._base import BaseHandler @@ -27,6 +28,8 @@ class DeviceHandler(BaseHandler): def __init__(self, hs): super(DeviceHandler, self).__init__(hs) + self.state = hs.get_state_handler() + @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, initial_device_display_name=None): @@ -45,29 +48,29 @@ class DeviceHandler(BaseHandler): str: device id (generated if none was supplied) """ if device_id is not None: - yield self.store.store_device( + new_device = yield self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, - ignore_if_known=True, ) + if new_device: + yield self.notify_device_update(user_id, device_id) defer.returnValue(device_id) # if the device id is not specified, we'll autogen one, but loop a few # times in case of a clash. 
attempts = 0 while attempts < 5: - try: - device_id = stringutils.random_string(10).upper() - yield self.store.store_device( - user_id=user_id, - device_id=device_id, - initial_device_display_name=initial_device_display_name, - ignore_if_known=False, - ) + device_id = stringutils.random_string(10).upper() + new_device = yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ) + if new_device: + yield self.notify_device_update(user_id, device_id) defer.returnValue(device_id) - except errors.StoreError: - attempts += 1 + attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @@ -147,6 +150,8 @@ class DeviceHandler(BaseHandler): user_id=user_id, device_id=device_id ) + yield self.notify_device_update(user_id, device_id) + @defer.inlineCallbacks def update_device(self, user_id, device_id, content): """ Update the given device @@ -166,12 +171,48 @@ class DeviceHandler(BaseHandler): device_id, new_display_name=content.get("display_name") ) + yield self.notify_device_update(user_id, device_id) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() else: raise + @defer.inlineCallbacks + def notify_device_update(self, user_id, device_id): + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + + hosts = set() + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) + + position = yield self.store.add_device_change_to_streams( + user_id, device_id, list(hosts) + ) + + yield self.notifier.on_new_event( + "device_list_key", position, rooms=room_ids, + ) + + for host in hosts: + self.federation.send_device_messages(host) + + @defer.inlineCallbacks + def get_device_list_changes(self, user_id, room_ids, from_key): + room_ids = frozenset(room_ids) + + user_ids_changed = set() + changed = yield self.store.get_user_whose_devices_changed(from_key) + for other_user_id in changed: + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + + defer.returnValue(user_ids_changed) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b63a660c06..38c2a2d39e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -259,6 +259,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) + yield self.device_handler.notify_device_update(user_id, device_id) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c880f61685..06bf626367 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. "to_device", # List of direct messages for the device. 
+ "device_lists", # List of user_ids whose devices have chanegd ])): __slots__ = [] @@ -143,6 +144,7 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) self.state = hs.get_state_handler() + self.device_handler = hs.get_device_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -544,6 +546,16 @@ class SyncHandler(object): yield self._generate_sync_entry_for_to_device(sync_result_builder) + if since_token and since_token.device_list_key: + user_id = sync_config.user.to_string() + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + device_lists = yield self.device_handler.get_device_list_changes( + user_id, joined_room_ids, since_token.device_list_key + ) + else: + device_lists = [] + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -551,6 +563,7 @@ class SyncHandler(object): invited=sync_result_builder.invited, archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, + device_lists=device_lists, next_batch=sync_result_builder.now_token, )) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 7199ec883a..b3d8001638 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id, filter.event_fields + sync_result.archived, time_now, requester.access_token_id, + filter.event_fields, ) response_content = { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists), + }, "presence": self.encode_presence( sync_result.presence, time_now ), diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e8495f1eb9..b9968debe5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id", + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") @@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=device_outbox_prefill, ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 963ef999d5..05374682fd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -387,6 +387,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True """ try: yield self.runInteraction( @@ -398,6 +402,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. 
if not or_ignore: raise + defer.returnValue(False) + defer.returnValue(True) @staticmethod def _simple_insert_txn(txn, table, values): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 17920d4480..b594f501f9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import ujson as json from twisted.internet import defer @@ -33,17 +34,13 @@ class DeviceStore(SQLBaseStore): user_id (str): id of user associated with the device device_id (str): id of device initial_device_display_name (str): initial displayname of the - device - ignore_if_known (bool): ignore integrity errors which mean the - device is already known + device. Ignored if device exists. Returns: - defer.Deferred - Raises: - StoreError: if ignore_if_known is False and the device was already - known + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. """ try: - yield self._simple_insert( + inserted = yield self._simple_insert( "devices", values={ "user_id": user_id, @@ -51,8 +48,9 @@ class DeviceStore(SQLBaseStore): "display_name": initial_device_display_name }, desc="store_device", - or_ignore=ignore_if_known, + or_ignore=True, ) + defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" " display_name=%s(%r) failed: %s", @@ -139,3 +137,156 @@ class DeviceStore(SQLBaseStore): ) defer.returnValue({d["device_id"]: d for d in devices}) + + def get_devices_by_remote(self, destination, from_stream_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + defer.returnValue((now_stream_id, [])) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id > ? AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + rows = txn.fetchall() + + if not rows: + return now_stream_id, [] + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in rows} + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND sent = ? 
+ """ + + results = [] + for user_id, user_devices in devices.iteritems(): + txn.execute(prev_sent_id_sql, (destination, user_id, True)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, result in user_devices.iteritems(): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": prev_id, + "stream_id": stream_id, + } + + prev_id = stream_id + + key_json = result.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = result.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.setdefault(user_id, {})[device_id] = result + + return now_stream_id, results + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row["user_id"] for row in rows)) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id < ( + SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + ) + """ + txn.execute(sql, (destination, destination, stream_id,)) + + sql = """ + UPDATE device_lists_outbound_pokes SET sent = ? + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (destination, True,)) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_id, hosts): + # device_lists_stream + # device_lists_outbound_pokes + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", self._add_device_change_txn, + user_id, device_id, hosts, stream_id, + ) + defer.returnValue(stream_id) + + def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + txn.call_after( + self._device_list_stream_cache.entity_has_changed, + user_id, stream_id, + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, stream_id, + ) + + self._simple_insert_txn( + txn, + table="device_lists_stream", + values={ + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + } + for destination in hosts + ] + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 385d607056..f82943a7a8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import collections - -import twisted.internet.defer +from twisted.internet import defer from ._base import SQLBaseStore @@ -33,7 +31,7 @@ class EndToEndKeyStore(SQLBaseStore): } ) - def get_e2e_device_keys(self, query_list): + def get_e2e_device_keys(self, query_list, include_all_devices=False): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. @@ -45,10 +43,11 @@ class EndToEndKeyStore(SQLBaseStore): return {} return self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + "get_e2e_device_keys", self._get_e2e_device_keys_txn, + query_list, include_all_devices, ) - def _get_e2e_device_keys_txn(self, txn, query_list): + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] @@ -63,23 +62,23 @@ class EndToEndKeyStore(SQLBaseStore): query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " + "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" + " %s JOIN devices d USING (user_id, device_id)" " WHERE %s" ) % ( + "FULL OUTER" if include_all_devices else "LEFT", " OR ".join("(" + q + ")" for q in query_clauses) ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - result = collections.defaultdict(dict) + result = {} for row in rows: - result[row["user_id"]][row["device_id"]] = row + result.setdefault(row["user_id"], {})[row["device_id"]] = row return result @@ -152,7 +151,7 @@ class EndToEndKeyStore(SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @twisted.internet.defer.inlineCallbacks + @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): yield self._simple_delete( table="e2e_device_keys_json", diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..61cac63bbb --- /dev/null +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -0,0 +1,56 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE device_list_streams_remote ( + list_id TEXT NOT NULL, + origin TEXT NOT NULL, + user_id TEXT NOT NULL, + is_full BOOLEAN NOT NULL, + ts BIGINT NOT NULL +); + +CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote( + origin, list_id, user_id +); + + +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); + + +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4d44c3d4ca..91a59b0bae 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -44,6 +44,7 @@ class EventSources(object): def get_current_token(self): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -63,6 +64,7 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) @@ -70,6 +72,7 @@ class EventSources(object): def get_current_token_for_room(self, room_id): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -89,5 +92,6 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 3a3ab21d17..9666f9d73f 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -158,6 +158,7 @@ class StreamToken( "account_data_key", "push_rules_key", "to_device_key", + "device_list_key", )) ): _SEPARATOR = "_" @@ -195,6 +196,7 @@ class StreamToken( or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) + or (int(other.device_list_key) < int(self.device_list_key)) ) def copy_and_advance(self, key, new_value): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index c718d1f98f..f88d2be7c5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", + "get_devices_by_remote", ]), state_handler=self.state_handler, handlers=None, @@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase): defer.succeed(retry_timings_res) ) + self.datastore.get_devices_by_remote.return_value = (0, []) + def get_received_txn_response(*args): return defer.succeed(None) self.datastore.get_received_txn_response = get_received_txn_response diff --git 
a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6bce352c5f..d746ea8568 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) @@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) -- cgit 1.4.1 From 51e9fe36e46331ac611cec1d4cb425c1bc98721c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Jan 2017 16:55:21 +0000 Subject: Fix up sending of m.device_list_update edus --- synapse/federation/transaction_queue.py | 121 ++++++++++++++++---------------- synapse/handlers/device.py | 1 + synapse/storage/devices.py | 40 +++++------ 3 files changed, 82 insertions(+), 80 deletions(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 65c6673a87..d18f6b6cfd 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -306,62 +306,74 @@ class TransactionQueue(object): yield run_on_reactor() while True: - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - pending_edus.extend( - self.pending_edus_keyed_by_dest.pop(destination, {}).values() - ) + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - ) + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, + ) - device_message_edus, device_stream_id = ( - yield self._get_new_device_messages(destination) - ) + device_message_edus, device_stream_id, dev_list_id = ( + yield self._get_new_device_messages(destination) + ) - pending_edus.extend(device_message_edus) - if pending_presence: - pending_edus.append( - Edu( - origin=self.server_name, - destination=destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self.clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, - ) + pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, ) + ) - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) - if 
not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - self.last_device_stream_id_by_dest[destination] = ( - device_stream_id + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + self.last_device_stream_id_by_dest[destination] = ( + device_stream_id + ) + return + + success = yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures, + limiter=limiter, + ) + if success: + # Remove the acknowledged device messages from the database + # Only bother if we actually sent some device messages + if device_message_edus: + yield self.store.delete_device_msgs_for_remote( + destination, device_stream_id + ) + logger.info("Marking as sent %r %r", destination, dev_list_id) + yield self.store.mark_as_sent_devices_by_remote( + destination, dev_list_id ) - return - success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, - device_stream_id, - includes_device_messages=bool(device_message_edus), - limiter=limiter, - ) - if not success: - break + self.last_device_stream_id_by_dest[destination] = device_stream_id + self.last_device_list_stream_id_by_dest[destination] = dev_list_id + else: + break except NotRetryingDestination: logger.debug( "TX [%s] not ready for retry yet - " @@ -374,8 +386,6 @@ class TransactionQueue(object): @defer.inlineCallbacks def _get_new_device_messages(self, destination): - # TODO: Send appropriate device list messages - last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) to_device_stream_id = self.store.get_to_device_stream_token() contents, stream_id = yield self.store.get_new_device_msgs_for_remote( @@ -404,13 +414,12 @@ class TransactionQueue(object): ) for content in results ) - defer.returnValue((edus, stream_id)) + defer.returnValue((edus, stream_id, now_stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures, device_stream_id, - includes_device_messages, limiter): + pending_failures, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -521,14 +530,6 @@ class TransactionQueue(object): "Failed to send event %s to %s", p.event_id, destination ) success = False - else: - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if includes_device_messages: - yield self.store.delete_device_msgs_for_remote( - destination, device_stream_id - ) - self.last_device_stream_id_by_dest[destination] = device_stream_id except RuntimeError as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. 
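
The devices.py hunk below fixes the bookkeeping for device_lists_outbound_pokes. The intended lifecycle can be exercised with plain sqlite3, a sketch using the table definition from the delta/40 migration above (host, user and device ids are invented):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE device_lists_outbound_pokes (
        destination TEXT NOT NULL,
        stream_id BIGINT NOT NULL,
        user_id TEXT NOT NULL,
        device_id TEXT NOT NULL,
        sent BOOLEAN NOT NULL
    )
""")

# A device change fans out into one unsent row per interested remote host.
conn.execute(
    "INSERT INTO device_lists_outbound_pokes VALUES (?, ?, ?, ?, ?)",
    ("remote.example.com", 5, "@alice:example.com", "JLAFKJWSCS", False),
)

# Once a transaction covering stream_id 5 succeeds, mark the rows sent.
# The parameter tuple must line up with the placeholders, which is
# exactly the bug corrected in the hunk below.
conn.execute(
    "UPDATE device_lists_outbound_pokes SET sent = ? "
    "WHERE destination = ? AND stream_id <= ?",
    (True, "remote.example.com", 5),
)
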
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d92780b642..ba4c48d590 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,6 +29,7 @@ class DeviceHandler(BaseHandler): super(DeviceHandler, self).__init__(hs) self.state = hs.get_state_handler() + self.federation = hs.get_federation_sender() @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b594f501f9..9628e2ff75 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -141,11 +141,11 @@ class DeviceStore(SQLBaseStore): def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() - has_changed = self._device_list_stream_cache.has_entity_changed( + has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) ) if not has_changed: - defer.returnValue((now_stream_id, [])) + return (now_stream_id, []) return self.runInteraction( "get_devices_by_remote", self._get_devices_by_remote_txn, @@ -165,7 +165,7 @@ class DeviceStore(SQLBaseStore): rows = txn.fetchall() if not rows: - return now_stream_id, [] + return (now_stream_id, []) # maps (user_id, device_id) -> stream_id query_map = {(r[0], r[1]): r[2] for r in rows} @@ -189,7 +189,7 @@ class DeviceStore(SQLBaseStore): result = { "user_id": user_id, "device_id": device_id, - "prev_id": prev_id, + "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, } @@ -202,9 +202,9 @@ class DeviceStore(SQLBaseStore): if device_display_name: result["device_display_name"] = device_display_name - results.setdefault(user_id, {})[device_id] = result + results.append(result) - return now_stream_id, results + return (now_stream_id, results) def mark_as_sent_devices_by_remote(self, destination, stream_id): return self.runInteraction( @@ -212,19 +212,6 @@ class DeviceStore(SQLBaseStore): destination, stream_id, ) - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) - - sql = """ - SELECT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) - defer.returnValue(set(row["user_id"] for row in rows)) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): sql = """ DELETE FROM device_lists_outbound_pokes @@ -239,7 +226,20 @@ class DeviceStore(SQLBaseStore): UPDATE device_lists_outbound_pokes SET sent = ? WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (destination, True,)) + txn.execute(sql, (True, destination, stream_id,)) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? 
+ """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_id, hosts): -- cgit 1.4.1 From c974116f197d211ba9b42159fe61cfd5957411b5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:06:54 +0000 Subject: Implement device key caching over federation --- synapse/federation/federation_client.py | 10 + synapse/federation/federation_server.py | 3 + synapse/federation/transport/client.py | 26 +++ synapse/federation/transport/server.py | 8 + synapse/handlers/device.py | 85 +++++++-- synapse/handlers/e2e_keys.py | 40 +++- synapse/storage/devices.py | 201 +++++++++++++++++++-- synapse/storage/end_to_end_keys.py | 4 +- .../schema/delta/40/device_list_streams.sql | 20 +- tests/handlers/test_device.py | 18 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_profile.py | 1 + tests/storage/test_appservice.py | 21 ++- 13 files changed, 381 insertions(+), 57 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c9175bb33d..b5bcfd705a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -126,6 +126,16 @@ class FederationClient(FederationBase): destination, content, timeout ) + @log_function + def query_user_devices(self, destination, user_id, timeout=30000): + """Query the device keys for a list of user ids hosted on a remote + server. + """ + sent_queries_counter.inc("user_devices") + return self.transport_layer.query_user_devices( + destination, user_id, timeout + ) + @log_function def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 862ccbef5d..e922b7ff4a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -416,6 +416,9 @@ class FederationServer(FederationBase): def on_query_client_keys(self, origin, content): return self.on_query_request("client_keys", content) + def on_query_user_devices(self, origin, user_id): + return self.on_query_request("user_devices", user_id) + @defer.inlineCallbacks @log_function def on_claim_client_keys(self, origin, content): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 915af34409..f49e8a2cc4 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -346,6 +346,32 @@ class TransportLayerClient(object): ) defer.returnValue(content) + @defer.inlineCallbacks + @log_function + def query_user_devices(self, destination, user_id, timeout): + """Query the devices for a user id hosted on a remote server. + + Response: + { + "stream_id": "...", + "devices": [ { ... } ] + } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the device keys. 
+ """ + path = PREFIX + "/user/devices/" + user_id + + content = yield self.client.get_json( + destination=destination, + path=path, + timeout=timeout, + ) + defer.returnValue(content) + @defer.inlineCallbacks @log_function def claim_client_keys(self, destination, query_content, timeout): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 159dbd1747..c840da834c 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet): return self.handler.on_query_client_keys(origin, content) +class FederationUserDevicesQueryServlet(BaseFederationServlet): + PATH = "/user/devices/(?P[^/]*)" + + def on_GET(self, origin, content, query, user_id): + return self.handler.on_query_user_devices(origin, user_id) + + class FederationClientKeysClaimServlet(BaseFederationServlet): PATH = "/user/keys/claim" @@ -613,6 +620,7 @@ SERVLET_CLASSES = ( FederationGetMissingEventsServlet, FederationEventAuthServlet, FederationClientKeysQueryServlet, + FederationUserDevicesQueryServlet, FederationClientKeysClaimServlet, FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ba4c48d590..2d66b3721a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ from synapse.api import errors from synapse.util import stringutils +from synapse.util.async import Linearizer from synapse.types import get_domain_from_id from twisted.internet import defer from ._base import BaseHandler @@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler): def __init__(self, hs): super(DeviceHandler, self).__init__(hs) + self.hs = hs self.state = hs.get_state_handler() - self.federation = hs.get_federation_sender() + self.federation_sender = hs.get_federation_sender() + self.federation = hs.get_replication_layer() + self._remote_edue_linearizer = Linearizer(name="remote_device_list") + + self.federation.register_edu_handler( + "m.device_list_update", self._incoming_device_list_update, + ) + self.federation.register_query_handler( + "user_devices", self.on_federation_query_user_devices, + ) @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, @@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler): initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) defer.returnValue(device_id) # if the device id is not specified, we'll autogen one, but loop a few @@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler): initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) defer.returnValue(device_id) attempts += 1 @@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler): user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): @@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler): device_id, new_display_name=content.get("display_name") ) - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() @@ -180,26 +191,28 @@ class 
DeviceHandler(BaseHandler): raise @defer.inlineCallbacks - def notify_device_update(self, user_id, device_id): + def notify_device_update(self, user_id, device_ids): rooms = yield self.store.get_rooms_for_user(user_id) room_ids = [r.room_id for r in rooms] hosts = set() - for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) - hosts.update(get_domain_from_id(u) for u in users) - hosts.discard(self.server_name) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) position = yield self.store.add_device_change_to_streams( - user_id, device_id, list(hosts) + user_id, device_ids, list(hosts) ) yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, ) + logger.info("Sending device list update notif to: %r", hosts) for host in hosts: - self.federation.send_device_messages(host) + self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def get_device_list_changes(self, user_id, room_ids, from_key): @@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler): defer.returnValue(user_ids_changed) + @defer.inlineCallbacks + def _incoming_device_list_update(self, origin, edu_content): + user_id = edu_content["user_id"] + device_id = edu_content["device_id"] + stream_id = edu_content["stream_id"] + prev_ids = edu_content.get("prev_id", []) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + return + + logger.info("Got edu: %r", edu_content) + + with (yield self._remote_edue_linearizer.queue(user_id)): + resync = True + if len(prev_ids) == 1: + extremity = yield self.store.get_device_list_remote_extremity(user_id) + logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) + if str(extremity) == str(prev_ids[0]): + resync = False + + if resync: + result = yield self.federation.query_user_devices(origin, user_id) + stream_id = result["stream_id"] + devices = result["devices"] + yield self.store.update_remote_device_list_cache( + user_id, devices, stream_id, + ) + device_ids = [device["device_id"] for device in devices] + yield self.notify_device_update(user_id, device_ids) + else: + content = dict(edu_content) + for key in ("user_id", "device_id", "stream_id", "prev_ids"): + content.pop(key, None) + yield self.store.update_remote_device_list_cache_entry( + user_id, device_id, content, stream_id, + ) + yield self.notify_device_update(user_id, [device_id]) + + @defer.inlineCallbacks + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + defer.returnValue({ + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + }) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 38c2a2d39e..832998a6d3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -73,8 +73,7 @@ class E2eKeysHandler(object): if self.is_mine_id(user_id): local_query[user_id] = device_ids else: - domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_ids + remote_queries[user_id] = device_ids # do the queries failures = {} @@ -85,9 +84,40 @@ class E2eKeysHandler(object): if user_id in local_query: results[user_id] = keys + remote_queries_not_in_cache = {} + if remote_queries: + query_list = [] + for user_id, 
device_ids in remote_queries.iteritems(): + if device_ids: + query_list.extend((user_id, device_id) for device_id in device_ids) + else: + query_list.append((user_id, None)) + + user_ids_not_in_cache, remote_results = ( + yield self.store.get_user_devices_from_cache( + query_list + ) + ) + for user_id, devices in remote_results.iteritems(): + user_devices = results.setdefault(user_id, {}) + for device_id, device in devices.iteritems(): + keys = device.get("keys", None) + device_display_name = device.get("device_display_name", None) + if keys: + result = dict(keys) + unsigned = result.setdefault("unsigned", {}) + if device_display_name: + unsigned["device_display_name"] = device_display_name + user_devices[device_id] = result + + for user_id in user_ids_not_in_cache: + domain = get_domain_from_id(user_id) + r = remote_queries_not_in_cache.setdefault(domain, {}) + r[user_id] = remote_queries[user_id] + @defer.inlineCallbacks def do_remote_query(destination): - destination_query = remote_queries[destination] + destination_query = remote_queries_not_in_cache[destination] try: limiter = yield get_retry_limiter( destination, self.clock, self.store @@ -119,7 +149,7 @@ class E2eKeysHandler(object): yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) - for destination in remote_queries + for destination in remote_queries_not_in_cache ])) defer.returnValue({ @@ -259,7 +289,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) - yield self.device_handler.notify_device_update(user_id, device_id) + yield self.device_handler.notify_device_update(user_id, [device_id]) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9628e2ff75..8ee3119db2 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + def get_device_list_remote_extremity(self, user_id): + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + 
"device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() @@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore): txn.execute(prev_sent_id_sql, (destination, user_id, True)) rows = txn.fetchall() prev_id = rows[0][0] - for device_id, result in user_devices.iteritems(): + for device_id, device in user_devices.iteritems(): stream_id = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore): prev_id = stream_id - key_json = result.get("key_json", None) + key_json = device.get("key_json", None) if key_json: result["keys"] = json.loads(key_json) - device_display_name = result.get("device_display_name", None) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + def get_user_devices_from_cache(self, query_list): + return self.runInteraction( + "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, + query_list, + ) + + def _get_user_devices_from_cache_txn(self, txn, query_list): + user_ids = {user_id for user_id, _ in query_list} + + user_ids_in_cache = set() + for user_id in user_ids: + stream_ids = self._simple_select_onecol_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + retcol="stream_id", + ) + if stream_ids: + user_ids_in_cache.add(user_id) + + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + content = self._simple_select_one_onecol_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + ) + results.setdefault(user_id, {})[device_id] = json.loads(content) + else: + devices = self._simple_select_list_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + ) + results[user_id] = { + device["device_id"]: json.loads(device["content"]) + for device in devices + } + user_ids_in_cache.discard(user_id) + + return user_ids_not_in_cache, results + + def get_devices_with_keys_by_user(self, user_id): + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + for user_id, user_devices in devices.iteritems(): + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + def mark_as_sent_devices_by_remote(self, destination, stream_id): 
return self.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, @@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_id, hosts): + def add_device_change_to_streams(self, user_id, device_ids, hosts): # device_lists_stream # device_lists_outbound_pokes with self._device_list_id_gen.get_next() as stream_id: yield self.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, - user_id, device_id, hosts, stream_id, + user_id, device_ids, hosts, stream_id, ) defer.returnValue(stream_id) - def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_id, @@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="device_lists_stream", - values={ - "stream_id": stream_id, - "user_id": user_id, - "device_id": device_id, - } + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] ) self._simple_insert_many_txn( @@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore): "sent": False, } for destination in hosts + for device_id in device_ids ] ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index f82943a7a8..a915c790ff 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore): query_params = [] for (user_id, device_id) in query_list: - query_clause = "k.user_id = ?" + query_clause = "user_id = ?" query_params.append(user_id) if device_id: - query_clause += " AND k.device_id = ?" + query_clause += " AND device_id = ?" query_params.append(device_id) query_clauses.append(query_clause) diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index 61cac63bbb..d1051c6ddf 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -13,18 +13,6 @@ * limitations under the License. 
*/ -CREATE TABLE device_list_streams_remote ( - list_id TEXT NOT NULL, - origin TEXT NOT NULL, - user_id TEXT NOT NULL, - is_full BOOLEAN NOT NULL, - ts BIGINT NOT NULL -); - -CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote( - origin, list_id, user_id -); - CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, @@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache ( CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 85a970a6c9..2eaaa8253c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver(handlers=None) - self.handler = synapse.handlers.device.DeviceHandler(hs) + hs = yield utils.setup_test_homeserver() + self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): res = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="display name" ) self.assertEqual(res, "fco") - dev = yield self.handler.store.get_device("boris", "fco") + dev = yield self.handler.store.get_device("@boris:foo", "fco") self.assertEqual(dev["display_name"], "display name") @defer.inlineCallbacks def test_device_is_preserved_if_exists(self): res1 = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="display name" ) self.assertEqual(res1, "fco") res2 = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="new display name" ) self.assertEqual(res2, "fco") - dev = yield self.handler.store.get_device("boris", "fco") + dev = yield self.handler.store.get_device("@boris:foo", "fco") self.assertEqual(dev["display_name"], "display name") @defer.inlineCallbacks def test_device_id_is_made_up_if_unspecified(self): device_id = yield self.handler.check_device_registered( - user_id="theresa", + user_id="@theresa:foo", device_id=None, initial_device_display_name="display" ) - dev = yield self.handler.store.get_device("theresa", device_id) + dev = yield self.handler.store.get_device("@theresa:foo", device_id) self.assertEqual(dev["display_name"], "display") @defer.inlineCallbacks diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5d602c1531..ceb9aa5765 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase): def setUp(self): self.mock_federation = Mock(spec=[ "make_query", + "register_edu_handler", ]) self.query_handlers = {} diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f1f664275f..979cebf600 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase): def setUp(self): self.mock_federation = Mock(spec=[ "make_query", + 
"register_edu_handler", ]) self.query_handlers = {} diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 9ff1abcd80..9e98d0e330 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) + hs = yield setup_test_homeserver( + config=config, + federation_sender=Mock(), + replication_layer=Mock(), + ) self.as_token = "token1" self.as_url = "some_url" @@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) + hs = yield setup_test_homeserver( + config=config, + federation_sender=Mock(), + replication_layer=Mock(), + ) self.db_pool = hs.get_db_pool() self.as_list = [ @@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) ApplicationServiceStore(hs) @@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) with self.assertRaises(ConfigError) as cm: @@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) with self.assertRaises(ConfigError) as cm: -- cgit 1.4.1 From fbfad76c03afe7538c67205ceb30825d9ce4fb07 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:30:37 +0000 Subject: Add comments --- synapse/handlers/device.py | 19 +++++++++-- synapse/handlers/e2e_keys.py | 4 ++- synapse/storage/devices.py | 37 ++++++++++++++++++++-- .../schema/delta/40/device_list_streams.sql | 8 ++++- 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2d66b3721a..a2ffd273bf 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -192,6 +192,9 @@ class DeviceHandler(BaseHandler): @defer.inlineCallbacks def notify_device_update(self, user_id, device_ids): + """Notify that a user's device(s) has changed. Pokes the notifier, and + remote servers if the user is local. + """ rooms = yield self.store.get_rooms_for_user(user_id) room_ids = [r.room_id for r in rooms] @@ -210,12 +213,16 @@ class DeviceHandler(BaseHandler): "device_list_key", position, rooms=room_ids, ) - logger.info("Sending device list update notif to: %r", hosts) - for host in hosts: - self.federation_sender.send_device_messages(host) + if hosts: + logger.info("Sending device list update notif to: %r", hosts) + for host in hosts: + self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def get_device_list_changes(self, user_id, room_ids, from_key): + """For a user and their joined rooms, calculate which device updates + we need to return. + """ room_ids = frozenset(room_ids) user_ids_changed = set() @@ -236,11 +243,14 @@ class DeviceHandler(BaseHandler): if get_domain_from_id(user_id) != origin: # TODO: Raise? 
+ logger.warning("Got device list update edu for %r from %r", user_id, origin) return logger.info("Got edu: %r", edu_content) with (yield self._remote_edue_linearizer.queue(user_id)): + # If the prev id matches whats in our cache table, then we don't need + # to resync the users device list, otherwise we do. resync = True if len(prev_ids) == 1: extremity = yield self.store.get_device_list_remote_extremity(user_id) @@ -249,6 +259,7 @@ class DeviceHandler(BaseHandler): resync = False if resync: + # Fetch all devices for the user. result = yield self.federation.query_user_devices(origin, user_id) stream_id = result["stream_id"] devices = result["devices"] @@ -258,6 +269,8 @@ class DeviceHandler(BaseHandler): device_ids = [device["device_id"] for device in devices] yield self.notify_device_update(user_id, device_ids) else: + # Simply update the single device, since we know that is the only + # change (becuase of the single prev_id matching the current cache) content = dict(edu_content) for key in ("user_id", "device_id", "stream_id", "prev_ids"): content.pop(key, None) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 832998a6d3..a16b9def8d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -75,7 +75,7 @@ class E2eKeysHandler(object): else: remote_queries[user_id] = device_ids - # do the queries + # Firt get local devices. failures = {} results = {} if local_query: @@ -84,6 +84,7 @@ class E2eKeysHandler(object): if user_id in local_query: results[user_id] = keys + # Now attempt to get any remote devices from our local cache. remote_queries_not_in_cache = {} if remote_queries: query_list = [] @@ -115,6 +116,7 @@ class E2eKeysHandler(object): r = remote_queries_not_in_cache.setdefault(domain, {}) r[user_id] = remote_queries[user_id] + # Now fetch any devices that we don't have in our cache @defer.inlineCallbacks def do_remote_query(destination): destination_query = remote_queries_not_in_cache[destination] diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 8ee3119db2..cf38dbaa3c 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -139,6 +139,9 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) def get_device_list_remote_extremity(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ return self._simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -149,6 +152,8 @@ class DeviceStore(SQLBaseStore): def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): + """Updates a single user's device in the cache. + """ return self.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, @@ -181,6 +186,8 @@ class DeviceStore(SQLBaseStore): ) def update_remote_device_list_cache(self, user_id, devices, stream_id): + """Replace the cache of the remote user's devices. + """ return self.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, @@ -222,6 +229,11 @@ class DeviceStore(SQLBaseStore): ) def get_devices_by_remote(self, destination, from_stream_id): + """Get stream of updates to send to remote servers + + Returns: + (now_stream_id, [ { updates }, .. 
]) + """ now_stream_id = self._device_list_id_gen.get_current_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( @@ -290,6 +302,17 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. + + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ return self.runInteraction( "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, query_list, @@ -347,6 +370,11 @@ class DeviceStore(SQLBaseStore): return user_ids_not_in_cache, results def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ return self.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, @@ -380,6 +408,8 @@ class DeviceStore(SQLBaseStore): return now_stream_id, [] def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ return self.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, stream_id, @@ -403,6 +433,8 @@ class DeviceStore(SQLBaseStore): @defer.inlineCallbacks def get_user_whose_devices_changed(self, from_key): + """Get set of users whose devices have changed since `from_key`. + """ from_key = int(from_key) changed = self._device_list_stream_cache.get_all_entities_changed(from_key) if changed is not None: @@ -416,8 +448,9 @@ class DeviceStore(SQLBaseStore): @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_ids, hosts): - # device_lists_stream - # device_lists_outbound_pokes + """Persist that a user's devices have been updated, and which hosts + (if any) should be poked. + """ with self._device_list_id_gen.get_next() as stream_id: yield self.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index d1051c6ddf..8348c143c3 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -13,7 +13,7 @@ * limitations under the License. */ - +-- Cache of remote devices. CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, @@ -23,6 +23,8 @@ CREATE TABLE device_lists_remote_cache ( CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +-- The last update we got for a user. Empty if we're not receiving updates for +-- that user. CREATE TABLE device_lists_remote_extremeties ( user_id TEXT NOT NULL, stream_id TEXT NOT NULL @@ -31,6 +33,7 @@ CREATE TABLE device_lists_remote_extremeties ( CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); +-- Stream of device lists updates. Includes both local and remotes CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, @@ -40,6 +43,9 @@ CREATE TABLE device_lists_stream ( CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); +-- The stream of updates to send to other servers. 
We keep at least one row +-- per user that was sent so that the prev_id for any new updates can be +-- calculated CREATE TABLE device_lists_outbound_pokes ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, -- cgit 1.4.1 From 76d40f490411ce1a0a208acb4242678b0cb4afb3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:39:33 +0000 Subject: Handle users leaving rooms --- synapse/handlers/device.py | 17 ++++++++++++++++- synapse/storage/devices.py | 8 ++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a2ffd273bf..1116dfd27c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -42,6 +42,8 @@ class DeviceHandler(BaseHandler): "user_devices", self.on_federation_query_user_devices, ) + hs.get_distributor().observe("user_left_room", self.user_left_room) + @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, initial_device_display_name=None): @@ -246,7 +248,11 @@ class DeviceHandler(BaseHandler): logger.warning("Got device list update edu for %r from %r", user_id, origin) return - logger.info("Got edu: %r", edu_content) + rooms = yield self.store.get_rooms_for_user(user_id) + if not rooms: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. + return with (yield self._remote_edue_linearizer.queue(user_id)): # If the prev id matches whats in our cache table, then we don't need @@ -288,6 +294,15 @@ class DeviceHandler(BaseHandler): "devices": devices, }) + @defer.inlineCallbacks + def user_left_room(self, user, room_id): + user_id = user.to_string() + rooms = yield self.store.get_rooms_for_user(user_id) + if not rooms: + # We no longer share rooms with this user, so we'll no longer + # receive device updates. Mark this in DB. + yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index cf38dbaa3c..1c48c3af99 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -150,6 +150,14 @@ class DeviceStore(SQLBaseStore): allow_none=True, ) + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + return self._simple_delete( + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + ) + def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): """Updates a single user's device in the cache. 
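+        Used on the incremental path of _incoming_device_list_update: the
+        device's content is upserted into device_lists_remote_cache and the
+        user's row in device_lists_remote_extremeties is advanced to the
+        given stream_id.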
-- cgit 1.4.1 From 31aca5589c3790201b2087e28d2901d00e1f77d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:55:50 +0000 Subject: Fix on sqlite: use left rather than outer join --- synapse/storage/end_to_end_keys.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index a915c790ff..441286d1a1 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -65,11 +65,11 @@ class EndToEndKeyStore(SQLBaseStore): "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" - " FROM e2e_device_keys_json k" - " %s JOIN devices d USING (user_id, device_id)" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" " WHERE %s" ) % ( - "FULL OUTER" if include_all_devices else "LEFT", + "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses) ) -- cgit 1.4.1 From b3e1f2aa7a1d583119378bb938ad476e72cc35ac Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 17:16:24 +0000 Subject: Fix unit tests --- tests/storage/test_end_to_end_keys.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 453bc61438..bfa6294250 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -35,6 +35,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = '{ "key": "value" }' + yield self.store.store_device( + "user", "device", None + ) + yield self.store.set_e2e_device_keys( "user", "device", now, json) @@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): def test_multiple_devices(self): now = 1470174257070 + yield self.store.store_device( + "user1", "device1", None + ) + yield self.store.store_device( + "user1", "device2", None + ) + yield self.store.store_device( + "user2", "device1", None + ) + yield self.store.store_device( + "user2", "device2", None + ) + yield self.store.set_e2e_device_keys( "user1", "device1", now, 'json11') yield self.store.set_e2e_device_keys( -- cgit 1.4.1 From f25a4a4692d9b4618efb64984c10a6e8243a4a0b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 10:27:39 +0000 Subject: Remove unused param --- synapse/storage/devices.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1c48c3af99..b99de2f1b0 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -26,8 +26,7 @@ logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): @defer.inlineCallbacks def store_device(self, user_id, device_id, - initial_device_display_name, - ignore_if_known=True): + initial_device_display_name): """Ensure the given device is known; add it to the store if not Args: -- cgit 1.4.1 From 888c59c955b33c3c69a73766507e134d64a8f25b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 10:29:47 +0000 Subject: Better name --- synapse/handlers/device.py | 4 +++- synapse/storage/devices.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1116dfd27c..ed077c9a76 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -259,7 +259,9 @@ class DeviceHandler(BaseHandler): # to resync the users device list, otherwise we do. 
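        # e.g. if our extremity for the user is 5 and the edu arrives with
        # prev_id: [5], only this one change was missed; prev_id: [3], or
        # several prev_ids, means there are updates we never saw, so fall
        # back to fetching the whole device list.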
resync = True if len(prev_ids) == 1: - extremity = yield self.store.get_device_list_remote_extremity(user_id) + extremity = yield self.store.get_device_list_last_stream_id_for_remote( + user_id + ) logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) if str(extremity) == str(prev_ids[0]): resync = False diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b99de2f1b0..d46203dd35 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -137,7 +137,7 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) - def get_device_list_remote_extremity(self, user_id): + def get_device_list_last_stream_id_for_remote(self, user_id): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ -- cgit 1.4.1 From 755adff0e407abd48bd30b544e7025da1381f3d2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 10:31:06 +0000 Subject: Use if rather than for --- synapse/storage/devices.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d46203dd35..ad5411dbe7 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -150,6 +150,8 @@ class DeviceStore(SQLBaseStore): ) def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for a remote user. + """ return self._simple_delete( table="device_lists_remote_extremeties", keyvalues={ @@ -394,7 +396,8 @@ class DeviceStore(SQLBaseStore): txn, [(user_id, None)], include_all_devices=True ) - for user_id, user_devices in devices.iteritems(): + if devices: + user_devices = devices[user_id] results = [] for device_id, device in user_devices.iteritems(): result = { -- cgit 1.4.1 From 738a2867c8d6e8c97b956b6a58c7373a49e60ddb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 10:31:29 +0000 Subject: SQL param ordering --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index ad5411dbe7..918520269e 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -260,7 +260,7 @@ class DeviceStore(SQLBaseStore): now_stream_id): sql = """ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id > ? AND stream_id <= ? AND sent = ? + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? GROUP BY user_id, device_id """ txn.execute( -- cgit 1.4.1 From c517a19c2ddc042e5a87dfaaecc55a790f62ed71 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 10:33:26 +0000 Subject: Comment --- synapse/storage/end_to_end_keys.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 441286d1a1..85763f7ceb 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -35,6 +35,8 @@ class EndToEndKeyStore(SQLBaseStore): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name".
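+            With include_all_devices set, devices that have not uploaded
+            device keys are still listed, with a null "key_json" (the join
+            against e2e_device_keys_json is LEFT rather than INNER).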
-- cgit 1.4.1 From 84a35f32c72693e2fd98677dc4e26e14ca8d56c5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 10:35:12 +0000 Subject: Comment --- synapse/storage/devices.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 918520269e..00317b0c1f 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -285,6 +285,8 @@ class DeviceStore(SQLBaseStore): results = [] for user_id, user_devices in devices.iteritems(): + # We bind literal True, as its database dependent how booleans are + # handled. txn.execute(prev_sent_id_sql, (destination, user_id, True)) rows = txn.fetchall() prev_id = rows[0][0] -- cgit 1.4.1 From 252b503fc8626078141dc6b82eeff63607874347 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 13:36:39 +0000 Subject: Hook device list updates to replication --- synapse/app/federation_sender.py | 3 +- synapse/app/synchrotron.py | 27 ++++++++++- synapse/handlers/device.py | 16 ------- synapse/handlers/sync.py | 35 ++++++++++---- synapse/replication/resource.py | 20 +++++++- synapse/replication/slave/storage/devices.py | 72 ++++++++++++++++++++++++++++ synapse/storage/devices.py | 15 ++++++ 7 files changed, 159 insertions(+), 29 deletions(-) create mode 100644 synapse/replication/slave/storage/devices.py diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index ec06620efb..411e47d98d 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState from synapse.util.async import sleep @@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice") class FederationSenderSlaveStore( SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, - SlavedRegistrationStore, + SlavedRegistrationStore, SlavedDeviceStore, ): pass diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 4dfc2dc648..9d250502e0 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.room import RoomStore from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore @@ -77,6 +78,7 @@ class SynchrotronSlavedStore( SlavedFilteringStore, SlavedPresenceStore, SlavedDeviceInboxStore, + SlavedDeviceStore, RoomStore, BaseSlavedStore, ClientIpStore, # After BaseSlavedStore because the constructor is different @@ -380,6 +382,28 @@ class SynchrotronServer(HomeServer): stream_key, position, users=users, rooms=rooms ) + @defer.inlineCallbacks + def notify_device_list_update(result): + stream = result.get("device_lists") + if not stream: + return + + 
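+            # each row mirrors the columns that ReplicationResource writes
+            # for the device_lists stream (position, user_id, destination),
+            # so look the offsets up by name rather than hard-coding them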
position_index = stream["field_names"].index("position") + user_index = stream["field_names"].index("user_id") + + for row in stream["rows"]: + logger.info("Handling device list row: %r", row) + position = row[position_index] + user_id = row[user_index] + + rooms = yield store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + + notifier.on_new_event( + "device_list_key", position, rooms=room_ids, + ) + + @defer.inlineCallbacks def notify(result): stream = result.get("events") if stream: @@ -417,6 +441,7 @@ class SynchrotronServer(HomeServer): notify_from_stream( result, "to_device", "to_device_key", user="user_id" ) + yield notify_device_list_update(result) while True: try: @@ -427,7 +452,7 @@ class SynchrotronServer(HomeServer): yield store.process_replication(result) typing_handler.process_replication(result) yield presence_handler.process_replication(result) - notify(result) + yield notify(result) except: logger.exception("Error replicating from %r", replication_url) yield sleep(5) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ed077c9a76..6fefb85890 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -220,22 +220,6 @@ class DeviceHandler(BaseHandler): for host in hosts: self.federation_sender.send_device_messages(host) - @defer.inlineCallbacks - def get_device_list_changes(self, user_id, room_ids, from_key): - """For a user and their joined rooms, calculate which device updates - we need to return. - """ - room_ids = frozenset(room_ids) - - user_ids_changed = set() - changed = yield self.store.get_user_whose_devices_changed(from_key) - for other_user_id in changed: - other_rooms = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(e.room_id for e in other_rooms): - user_ids_changed.add(other_user_id) - - defer.returnValue(user_ids_changed) - @defer.inlineCallbacks def _incoming_device_list_update(self, origin, edu_content): user_id = edu_content["user_id"] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 06bf626367..9199f20817 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -144,7 +144,6 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) self.state = hs.get_state_handler() - self.device_handler = hs.get_device_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -546,15 +545,9 @@ class SyncHandler(object): yield self._generate_sync_entry_for_to_device(sync_result_builder) - if since_token and since_token.device_list_key: - user_id = sync_config.user.to_string() - rooms = yield self.store.get_rooms_for_user(user_id) - joined_room_ids = set(r.room_id for r in rooms) - device_lists = yield self.device_handler.get_device_list_changes( - user_id, joined_room_ids, since_token.device_list_key - ) - else: - device_lists = [] + device_lists = yield self._generate_sync_entry_for_device_list( + sync_result_builder + ) defer.returnValue(SyncResult( presence=sync_result_builder.presence, @@ -567,6 +560,28 @@ class SyncHandler(object): next_batch=sync_result_builder.now_token, )) + @defer.inlineCallbacks + def _generate_sync_entry_for_device_list(self, sync_result_builder): + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + + if since_token and since_token.device_list_key: + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(r.room_id for r in rooms) + + user_ids_changed = set() + 
changed = yield self.store.get_user_whose_devices_changed( + since_token.device_list_key + ) + for other_user_id in changed: + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + + defer.returnValue(user_ids_changed) + else: + defer.returnValue([]) + @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): """Generates the portion of the sync response. Populates diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 4616e9b34a..36548c5eda 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -46,6 +46,7 @@ STREAM_NAMES = ( ("to_device",), ("public_rooms",), ("federation",), + ("device_lists",), ) @@ -140,6 +141,7 @@ class ReplicationResource(Resource): caches_token = self.store.get_cache_stream_token() public_rooms_token = self.store.get_current_public_room_stream_id() federation_token = self.federation_sender.get_current_token() + device_list_token = self.store.get_device_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -155,6 +157,7 @@ class ReplicationResource(Resource): int(stream_token.to_device_key), int(public_rooms_token), int(federation_token), + int(device_list_token), )) @request_handler() @@ -214,6 +217,7 @@ class ReplicationResource(Resource): yield self.caches(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams) + yield self.device_lists(writer, current_token, limit, request_streams) self.federation(writer, current_token, limit, request_streams, federation_ack) self.streams(writer, current_token, request_streams) @@ -495,6 +499,20 @@ class ReplicationResource(Resource): "position", "type", "content", ), position=upto_token) + @defer.inlineCallbacks + def device_lists(self, writer, current_token, limit, request_streams): + current_position = current_token.device_lists + + device_lists = request_streams.get("device_lists") + + if device_lists is not None and device_lists != current_position: + changes = yield self.store.get_users_and_hosts_device_list_changes( + device_lists, + ) + writer.write_header_and_rows("device_lists", changes, ( + "position", "user_id", "destination", + ), position=current_position) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -527,7 +545,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", - "federation", + "federation", "device_lists", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py new file mode 100644 index 0000000000..ca46aa17b6 --- /dev/null +++ b/synapse/replication/slave/storage/devices.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedDeviceStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedDeviceStore, self).__init__(db_conn, hs) + + self.hs = hs + + self._device_list_id_gen = SlavedIdTracker( + db_conn, "device_lists_stream", "stream_id", + ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + + get_device_stream_token = DataStore.get_device_stream_token.__func__ + get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__ + get_devices_by_remote = DataStore.get_devices_by_remote.__func__ + _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__ + _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__ + mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__ + _mark_as_sent_devices_by_remote_txn = ( + DataStore._mark_as_sent_devices_by_remote_txn.__func__ + ) + + def stream_positions(self): + result = super(SlavedDeviceStore, self).stream_positions() + result["device_lists"] = self._device_list_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("device_lists") + if stream: + self._device_list_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + stream_id = row[0] + user_id = row[1] + destination = row[2] + + self._device_list_stream_cache.entity_has_changed( + user_id, stream_id + ) + + if destination: + self._device_list_federation_stream_cache.entity_has_changed( + destination, stream_id + ) + + return super(SlavedDeviceStore, self).process_replication(result) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 00317b0c1f..2b2cebacfa 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -458,6 +458,21 @@ class DeviceStore(SQLBaseStore): rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row["user_id"] for row in rows)) + def get_users_and_hosts_device_list_changes(self, from_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + sql = """ + SELECT stream_id, user_id, destination FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE stream_id > ? 
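+            -- LEFT JOIN so that purely local changes (which have no
+            -- outbound poke row) still appear, with a NULL destination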
+ """ + return self._execute( + "get_users_and_hosts_device_list", None, + sql, from_key, + ) + @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_ids, hosts): """Persist that a user's devices have been updated, and which hosts -- cgit 1.4.1 From d1e1fd62108c9b285d7c57d357311c6d5df2190e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Jan 2017 15:23:48 +0000 Subject: Add ts column to device_lists_outbound_pokes --- synapse/storage/devices.py | 3 +++ synapse/storage/schema/delta/40/device_list_streams.sql | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 2b2cebacfa..89c7bc0cc0 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -486,6 +486,8 @@ class DeviceStore(SQLBaseStore): defer.returnValue(stream_id) def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): + now = self._clock.time_msec() + txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_id, @@ -519,6 +521,7 @@ class DeviceStore(SQLBaseStore): "user_id": user_id, "device_id": device_id, "sent": False, + "ts": now, } for destination in hosts for device_id in device_ids diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index 8348c143c3..54841b3843 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -51,7 +51,8 @@ CREATE TABLE device_lists_outbound_pokes ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, - sent BOOLEAN NOT NULL + sent BOOLEAN NOT NULL, + ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers ); CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); -- cgit 1.4.1 From 76100203aba979e21f3831e3a675ae4e3d578ad4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 10:11:46 +0000 Subject: Always use the latest stream_id, sent or unsent --- synapse/storage/devices.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 89c7bc0cc0..d72f60d94b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -280,14 +280,14 @@ class DeviceStore(SQLBaseStore): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND sent = ? + WHERE destination = ? AND user_id = ? AND stream_id <= ? """ results = [] for user_id, user_devices in devices.iteritems(): - # We bind literal True, as its database dependent how booleans are - # handled. - txn.execute(prev_sent_id_sql, (destination, user_id, True)) + # The prev_id for the first row is always the last row before + # `from_stream_id` + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) rows = txn.fetchall() prev_id = rows[0][0] for device_id, device in user_devices.iteritems(): -- cgit 1.4.1 From d360c97ae1e79789e97ab6a12e005d22334e416f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 10:12:00 +0000 Subject: Clear out old destination pokes. 
--- synapse/storage/devices.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d72f60d94b..c05ca7c5e0 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,6 +24,13 @@ logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): + def __init__(self, hs): + super(DeviceStore, self).__init__(hs) + + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): @@ -530,3 +537,38 @@ class DeviceStore(SQLBaseStore): def get_device_stream_token(self): return self._device_list_id_gen.get_current_token() + + def _prune_old_outbound_device_pokes(self): + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. We keep one entry per + (destination, user_id) tuple to ensure that the prev_ids remain correct + if the server does come back. + """ + now = self._clock.time_msec() + + def _prune_txn(txn): + select_sql = """ + SELECT destination, user_id, max(stream_id) as stream_id + FROM device_lists_outbound_pokes + GROUP BY destination, user_id + """ + + txn.execute(select_sql) + rows = txn.fetchall() + + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + """ + + txn.executemany( + delete_sql, + ( + (now, row["destination"], row["user_id"], row["stream_id"]) + for row in rows + ) + ) + + return self.runInteraction( + "_prune_old_outbound_device_pokes", _prune_txn + ) -- cgit 1.4.1 From 4ac363a16886b05cf15064932b6510cdff729c57 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 14:10:12 +0000 Subject: Remove debug logging --- synapse/app/synchrotron.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 9d250502e0..b3fb408cfd 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -392,7 +392,6 @@ class SynchrotronServer(HomeServer): user_index = stream["field_names"].index("user_id") for row in stream["rows"]: - logger.info("Handling device list row: %r", row) position = row[position_index] user_id = row[user_index] -- cgit 1.4.1 From 3670025e641e338de14a012a24d0ceb1cade194c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 14:11:31 +0000 Subject: Rename func --- synapse/replication/resource.py | 2 +- synapse/storage/devices.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 36548c5eda..a30e647474 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -506,7 +506,7 @@ class ReplicationResource(Resource): device_lists = request_streams.get("device_lists") if device_lists is not None and device_lists != current_position: - changes = yield self.store.get_users_and_hosts_device_list_changes( + changes = yield self.store.get_all_device_list_changes_for_remotes( device_lists, ) writer.write_header_and_rows("device_lists", changes, ( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c05ca7c5e0..e68ee50152 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -465,7 +465,7 @@ class DeviceStore(SQLBaseStore): rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row["user_id"] for row in 
rows)) - def get_users_and_hosts_device_list_changes(self, from_key): + def get_all_device_list_changes_for_remotes(self, from_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. -- cgit 1.4.1 From 828db669ecf4a631a4cab0c78d2ea8df7c531716 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 16:37:22 +0000 Subject: Use get_users_in_room and declare it iterable --- synapse/handlers/device.py | 2 +- synapse/storage/roommember.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6fefb85890..7245d14fab 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -203,7 +203,7 @@ class DeviceHandler(BaseHandler): hosts = set() if self.hs.is_mine_id(user_id): for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.store.get_users_in_room(room_id) hosts.update(get_domain_from_id(u) for u in users) hosts.discard(self.server_name) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 768e0a4451..6cf1a538ae 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -131,7 +131,7 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=5000) + @cached(max_entries=1000000, iterable=True) def get_users_in_room(self, room_id): def f(txn): -- cgit 1.4.1 From e75a779d9e1e57353ea3e718efd72a88ce8e71e3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 16:38:20 +0000 Subject: Fix query --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index e68ee50152..3a6d2cbcd6 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -463,7 +463,7 @@ class DeviceStore(SQLBaseStore): SELECT user_id FROM device_lists_stream WHERE stream_id > ? 
""" rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) - defer.returnValue(set(row["user_id"] for row in rows)) + defer.returnValue(set(row[0] for row in rows)) def get_all_device_list_changes_for_remotes(self, from_key): """Return a list of `(stream_id, user_id, destination)` which is the -- cgit 1.4.1 From c2c9a78db9393bafed59023298e71ab2c9fc8ae7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 16:55:04 +0000 Subject: Noop device key changes if they're the same --- synapse/handlers/e2e_keys.py | 9 ++++--- synapse/storage/devices.py | 1 + synapse/storage/end_to_end_keys.py | 50 +++++++++++++++++++++++++++++--------- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a16b9def8d..49b277a1af 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -287,11 +287,12 @@ class E2eKeysHandler(object): device_id, user_id, time_now ) # TODO: Sign the JSON with the server key - yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, - encode_canonical_json(device_keys) + changed = yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, device_keys, ) - yield self.device_handler.notify_device_update(user_id, [device_id]) + if changed: + # Only notify about device updates *if* the keys actually changed + yield self.device_handler.notify_device_update(user_id, [device_id]) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 3a6d2cbcd6..f0353929da 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -164,6 +164,7 @@ class DeviceStore(SQLBaseStore): keyvalues={ "user_id": user_id, }, + desc="mark_remote_user_device_list_as_unsubscribed", ) def update_remote_device_list_cache_entry(self, user_id, device_id, content, diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 85763f7ceb..aa54d7637c 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,21 +14,49 @@ # limitations under the License. from twisted.internet import defer +from canonicaljson import encode_canonical_json + from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): - return self._simple_upsert( - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": json_bytes, - } + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. 
+ """ + def _set_e2e_device_keys_txn(txn): + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="key_json", + allow_none=True, + ) + + new_key_json = encode_canonical_json(device_keys) + if old_key_json == new_key_json: + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": new_key_json, + } + ) + + return True + + return self.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn ) def get_e2e_device_keys(self, query_list, include_all_devices=False): -- cgit 1.4.1 From fd1c18c0887321934f89e38ab9d62b677128fffb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 17:00:24 +0000 Subject: Use DB cache of joined users for presence --- synapse/handlers/presence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1b89dc6274..9982ae0fed 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -574,7 +574,7 @@ class PresenceHandler(object): if not local_states: continue - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.store.get_users_in_room(room_id) hosts = set(get_domain_from_id(u) for u in users) for host in hosts: @@ -766,7 +766,7 @@ class PresenceHandler(object): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. - user_ids = yield self.state.get_current_user_in_room(room_id) + user_ids = yield self.store.get_users_in_room(room_id) if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) @@ -1069,7 +1069,7 @@ class PresenceEventSource(object): user_ids_to_check = set() for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.store.get_users_in_room(room_id) user_ids_to_check.update(users) user_ids_to_check.update(friends) -- cgit 1.4.1 From c7a26b7c3243c3187a8d12060cb2d2a02d318260 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 17:11:24 +0000 Subject: Fix unit tests --- synapse/handlers/e2e_keys.py | 2 +- synapse/storage/end_to_end_keys.py | 12 ++++++++++-- tests/storage/test_end_to_end_keys.py | 8 ++++---- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 49b277a1af..e40495d1ab 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -194,7 +194,7 @@ class E2eKeysHandler(object): # "unsigned" section for user_id, device_keys in results.items(): for device_id, device_info in device_keys.items(): - r = json.loads(device_info["key_json"]) + r = dict(device_info["keys"]) r["unsigned"] = {} display_name = device_info["device_display_name"] if display_name is not None: diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index aa54d7637c..2040e022fa 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -15,6 +15,7 @@ from twisted.internet import defer from canonicaljson import encode_canonical_json +import ujson as json from ._base import SQLBaseStore @@ -59,6 +60,7 @@ class EndToEndKeyStore(SQLBaseStore): "set_e2e_device_keys", _set_e2e_device_keys_txn ) + @defer.inlineCallbacks def get_e2e_device_keys(self, 
query_list, include_all_devices=False): """Fetch a list of device keys. Args: @@ -70,13 +72,19 @@ class EndToEndKeyStore(SQLBaseStore): dict containing "key_json", "device_display_name". """ if not query_list: - return {} + defer.returnValue({}) - return self.runInteraction( + results = yield self.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, include_all_devices, ) + for user_id, device_keys in results.iteritems(): + for device_id, device_info in device_keys.iteritems(): + device_info["keys"] = json.loads(device_info.pop("key_json")) + + defer.returnValue(results) + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index bfa6294250..84ce492a2c 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -33,7 +33,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_key_without_device_name(self): now = 1470174257070 - json = '{ "key": "value" }' + json = {"key": "value"} yield self.store.store_device( "user", "device", None @@ -47,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("device", res["user"]) dev = res["user"]["device"] self.assertDictContainsSubset({ - "key_json": json, + "keys": json, "device_display_name": None, }, dev) @defer.inlineCallbacks def test_get_key_with_device_name(self): now = 1470174257070 - json = '{ "key": "value" }' + json = {"key": "value"} yield self.store.set_e2e_device_keys( "user", "device", now, json) @@ -67,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("device", res["user"]) dev = res["user"]["device"] self.assertDictContainsSubset({ - "key_json": json, + "keys": json, "device_display_name": "display_name", }, dev) -- cgit 1.4.1 From 1c13c9f6b6298b8f2688eb86b6a4f6af6b1d1fca Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 17:12:14 +0000 Subject: Don't have such a large cache --- synapse/storage/roommember.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 6cf1a538ae..0fdcf29085 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -131,7 +131,7 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=1000000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): def f(txn): -- cgit 1.4.1 From 4b3403ca9b87a8187ea597027a82be9fe005cfb9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 17:24:32 +0000 Subject: Stream cache invalidations for room membership storage functions --- synapse/storage/events.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 8712d7e18c..6685b9da1c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -564,9 +564,13 @@ class EventsStore(SQLBaseStore): ) for member in members_changed: - txn.call_after(self.get_rooms_for_user.invalidate, (member,)) + self._invalidate_cache_and_stream( + txn, self.get_rooms_for_user, (member,) + ) - txn.call_after(self.get_users_in_room.invalidate, (room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_users_in_room, (room_id,) + ) # Add an 
entry to the current_state_resets table to record the point
         # where we clobbered the current state
--
cgit 1.4.1


From 05b9f48ee577f1cbdd5c5837f22c0d9cbe4c44cc Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 10:08:55 +0000
Subject: Fix clearing out old device list outbound pokes

---
 synapse/storage/devices.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index f0353929da..e6fe67ee25 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -545,7 +545,7 @@ class DeviceStore(SQLBaseStore):
         (destination, user_id) tuple to ensure that the prev_ids remain correct
         if the server does come back.
         """
-        now = self._clock.time_msec()
+        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000

         def _prune_txn(txn):
             select_sql = """
@@ -557,6 +557,9 @@ class DeviceStore(SQLBaseStore):
             txn.execute(select_sql)
             rows = txn.fetchall()

+            if not rows:
+                return
+
             delete_sql = """
                 DELETE FROM device_lists_outbound_pokes
                 WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
@@ -565,11 +568,13 @@ class DeviceStore(SQLBaseStore):
             txn.executemany(
                 delete_sql,
                 (
-                    (now, row["destination"], row["user_id"], row["stream_id"])
+                    (yesterday, row[0], row[1], row[2])
                     for row in rows
                 )
             )

+            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+
         return self.runInteraction(
             "_prune_old_outbound_device_pokes", _prune_txn
         )
--
cgit 1.4.1


From d3169e8d28a4b7238256ff4d3151e3cc8feef0e1 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 11:20:03 +0000
Subject: Only fetch with row ts and count > 1

---
 synapse/storage/devices.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index e6fe67ee25..cccefdd3d2 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -552,9 +552,10 @@ class DeviceStore(SQLBaseStore):
                 SELECT destination, user_id, max(stream_id) as stream_id
                 FROM device_lists_outbound_pokes
                 GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
             """

-            txn.execute(select_sql)
+            txn.execute(select_sql, (yesterday,))
             rows = txn.fetchall()

             if not rows:
--
cgit 1.4.1


From ab55794b6f57988204605f3b1e7245a66e91dcec Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 13:22:41 +0000
Subject: Fix deletion of old sent devices correctly

---
 synapse/storage/devices.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index cccefdd3d2..8e17800364 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -436,15 +436,27 @@ class DeviceStore(SQLBaseStore):
         )

     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+        # First we DELETE all rows such that only the latest row for each
+        # (destination, user_id) is left. We do this by selecting first and
+        # deleting.
+        sql = """
+            SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND stream_id <= ?
+            GROUP BY user_id
+            HAVING count(*) > 1
+        """
+        txn.execute(sql, (destination, stream_id,))
+        rows = txn.fetchall()
+
         sql = """
             DELETE FROM device_lists_outbound_pokes
-            WHERE destination = ? AND stream_id < (
-                SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
-                WHERE destination = ? AND stream_id <= ?
-            )
+            WHERE destination = ? AND user_id = ? AND stream_id < ?
""" - txn.execute(sql, (destination, destination, stream_id,)) + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows) + ) + # Mark everything that is left as sent sql = """ UPDATE device_lists_outbound_pokes SET sent = ? WHERE destination = ? AND stream_id <= ? -- cgit 1.4.1 From ae7a132f38404e9f654ab1b7c5dd84ba6a3efda6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 Jan 2017 13:40:09 +0000 Subject: Better handle 404 response for federation /send/ --- synapse/federation/transaction_queue.py | 1 + synapse/util/retryutils.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index d18f6b6cfd..cb106c6a1b 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -319,6 +319,7 @@ class TransactionQueue(object): destination, self.clock, self.store, + backoff_on_404=True, # If we get a 404 the other side has gone ) device_message_edus, device_stream_id, dev_list_id = ( diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index e2de7fce91..cc88a0b532 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -88,7 +88,7 @@ class RetryDestinationLimiter(object): def __init__(self, destination, clock, store, retry_interval, min_retry_interval=10 * 60 * 1000, max_retry_interval=24 * 60 * 60 * 1000, - multiplier_retry_interval=5,): + multiplier_retry_interval=5, backoff_on_404=False): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -107,6 +107,7 @@ class RetryDestinationLimiter(object): a failed request, in milliseconds. multiplier_retry_interval (int): The multiplier to use to increase the retry interval after a failed request. + backoff_on_404 (bool): Back off if we get a 404 """ self.clock = clock self.store = store @@ -116,6 +117,7 @@ class RetryDestinationLimiter(object): self.min_retry_interval = min_retry_interval self.max_retry_interval = max_retry_interval self.multiplier_retry_interval = multiplier_retry_interval + self.backoff_on_404 = backoff_on_404 def __enter__(self): pass @@ -123,7 +125,16 @@ class RetryDestinationLimiter(object): def __exit__(self, exc_type, exc_val, exc_tb): valid_err_code = False if exc_type is not None and issubclass(exc_type, CodeMessageException): - valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 + if exc_val.code < 400: + valid_err_code = True + elif exc_val.code == 404 and self.backoff_on_404: + valid_err_code = False + elif exc_val.code == 429: + valid_err_code = False + elif exc_val.code < 500: + valid_err_code = True + else: + valid_err_code = False if exc_type is None or valid_err_code: # We connected successfully. -- cgit 1.4.1 From 85c590105f87a6cd138f1509f70087aa0881cf2d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 Jan 2017 13:46:38 +0000 Subject: Comment --- synapse/util/retryutils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index cc88a0b532..0961dd5b25 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -125,6 +125,10 @@ class RetryDestinationLimiter(object): def __exit__(self, exc_type, exc_val, exc_tb): valid_err_code = False if exc_type is not None and issubclass(exc_type, CodeMessageException): + # Some error codes are perfectly fine for some APIs, whereas other + # APIs may expect to never received e.g. a 404. 
It's important to
+            # handle 404 as some remote servers will return a 404 when the HS
+            # has been decommissioned.
             if exc_val.code < 400:
                 valid_err_code = True
             elif exc_val.code == 404 and self.backoff_on_404:
--
cgit 1.4.1


From 4c0ec15bdcb8fbecf5e4f6cdd3017c9c53076972 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 13:53:46 +0000
Subject: Comment

---
 synapse/util/retryutils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 0961dd5b25..5c7fc1afb4 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -134,6 +134,8 @@ class RetryDestinationLimiter(object):
             elif exc_val.code == 404 and self.backoff_on_404:
                 valid_err_code = False
             elif exc_val.code == 429:
+                # 429 is us being aggressively rate limited, so let's rate limit
+                # ourselves.
                 valid_err_code = False
             elif exc_val.code < 500:
                 valid_err_code = True
--
cgit 1.4.1


From 21b73757780cc8609e895cd851a3b5072c8a7e32 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 15:15:41 +0000
Subject: Add an index to make membership queries faster

---
 synapse/storage/roommember.py                         |  2 +-
 synapse/storage/schema/delta/40/current_state_idx.sql | 17 +++++++++++++++++
 synapse/storage/state.py                              |  8 ++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)
 create mode 100644 synapse/storage/schema/delta/40/current_state_idx.sql

diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0fdcf29085..10f7c7a4bc 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -220,7 +220,7 @@ class RoomMemberStore(SQLBaseStore):
             " ON e.event_id = c.event_id"
             " AND m.room_id = c.room_id"
             " AND m.user_id = c.state_key"
-            " WHERE %s"
+            " WHERE c.type = 'm.room.member' AND %s"
         ) % (where_clause,)

         txn.execute(sql, args)
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/schema/delta/40/current_state_idx.sql
new file mode 100644
index 0000000000..7ffa189f39
--- /dev/null
+++ b/synapse/storage/schema/delta/40/current_state_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('current_state_members_idx', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d653327c..1b3800eb6a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -49,6 +49,7 @@ class StateStore(SQLBaseStore):

     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"

     def __init__(self, hs):
         super(StateStore, self).__init__(hs)
@@ -60,6 +61,13 @@ class StateStore(SQLBaseStore):
             self.STATE_GROUP_INDEX_UPDATE_NAME,
             self._background_index_state,
         )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )

     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
--
cgit 1.4.1


From fe08db2713cb35e1424034d58d750ebdc52cedbc Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 15:21:32 +0000
Subject: Remove explicit < 400 check as apparently this is confusing

---
 synapse/util/retryutils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 5c7fc1afb4..b94ae369cf 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -129,9 +129,7 @@ class RetryDestinationLimiter(object):
             # APIs may expect to never receive e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
             # has been decommissioned.
-            if exc_val.code < 400:
-                valid_err_code = True
-            elif exc_val.code == 404 and self.backoff_on_404:
+            if exc_val.code == 404 and self.backoff_on_404:
                 valid_err_code = False
             elif exc_val.code == 429:
--
cgit 1.4.1


From 458b6f473314a81d7e671fc2fc8c30d3259924c4 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 16:09:03 +0000
Subject: Only invalidate membership caches based on the cache stream

Previously we completely invalidated get_users_in_room whenever we
updated the current_state_events table. This was way too aggressive.
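
Roughly, the pattern changes as follows (a sketch, not the exact diff below;
it assumes the `_invalidate_cache_and_stream` helper used by the earlier
"Stream cache invalidations for room membership storage functions" commit,
and the method name here is hypothetical):

    def _invalidate_membership_caches_txn(self, txn, room_id, members_changed):
        # Old pattern was txn.call_after(cache.invalidate, key), which only
        # invalidated the cache in the local process once the transaction
        # committed; workers never heard about it.
        #
        # New pattern: invalidate locally *and* write a row to the caches
        # replication stream, so workers replicating from the master drop
        # the same entries.
        for member in members_changed:
            self._invalidate_cache_and_stream(
                txn, self.get_rooms_for_user, (member,)
            )
        self._invalidate_cache_and_stream(
            txn, self.get_users_in_room, (room_id,)
        )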
--- synapse/replication/resource.py | 3 --- synapse/replication/slave/storage/events.py | 21 +++++---------------- synapse/storage/events.py | 20 -------------------- synapse/storage/roommember.py | 2 -- 4 files changed, 5 insertions(+), 41 deletions(-) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index a30e647474..d8eb14592b 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -299,9 +299,6 @@ class ReplicationResource(Resource): "backward_ex_outliers", res.backward_ex_outliers, ("position", "event_id", "state_group"), ) - writer.write_header_and_rows( - "state_resets", res.state_resets, ("position",), - ) @defer.inlineCallbacks def presence(self, writer, current_token, request_streams): diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b3f3bf7488..15a025a019 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -192,10 +192,6 @@ class SlavedEventStore(BaseSlavedStore): return result def process_replication(self, result): - state_resets = set( - r[0] for r in result.get("state_resets", {"rows": []})["rows"] - ) - stream = result.get("events") if stream: self._stream_id_gen.advance(int(stream["position"])) @@ -205,7 +201,7 @@ class SlavedEventStore(BaseSlavedStore): for row in stream["rows"]: self._process_replication_row( - row, backfilled=False, state_resets=state_resets + row, backfilled=False, ) stream = result.get("backfill") @@ -213,7 +209,7 @@ class SlavedEventStore(BaseSlavedStore): self._backfill_id_gen.advance(-int(stream["position"])) for row in stream["rows"]: self._process_replication_row( - row, backfilled=True, state_resets=state_resets + row, backfilled=True, ) stream = result.get("forward_ex_outliers") @@ -232,20 +228,15 @@ class SlavedEventStore(BaseSlavedStore): return super(SlavedEventStore, self).process_replication(result) - def _process_replication_row(self, row, backfilled, state_resets): - position = row[0] + def _process_replication_row(self, row, backfilled): internal = json.loads(row[1]) event_json = json.loads(row[2]) event = FrozenEvent(event_json, internal_metadata_dict=internal) self.invalidate_caches_for_event( - event, backfilled, reset_state=position in state_resets + event, backfilled, ) - def invalidate_caches_for_event(self, event, backfilled, reset_state): - if reset_state: - self.get_rooms_for_user.invalidate_all() - self.get_users_in_room.invalidate((event.room_id,)) - + def invalidate_caches_for_event(self, event, backfilled): self._invalidate_get_event_cache(event.event_id) self.get_latest_event_ids_in_room.invalidate((event.room_id,)) @@ -267,8 +258,6 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.redacts) if event.type == EventTypes.Member: - self.get_rooms_for_user.invalidate((event.state_key,)) - self.get_users_in_room.invalidate((event.room_id,)) self._membership_stream_cache.entity_has_changed( event.state_key, event.internal_metadata.stream_ordering ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6685b9da1c..f4352b326b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -572,14 +572,6 @@ class EventsStore(SQLBaseStore): txn, self.get_users_in_room, (room_id,) ) - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": 
max_stream_order}
-        )
-
         for room_id, new_extrem in new_forward_extremeties.items():
             self._simple_delete_txn(
                 txn,
@@ -1610,15 +1602,6 @@ class EventsStore(SQLBaseStore):
             else:
                 upper_bound = current_forward_id

-                sql = (
-                    "SELECT event_stream_ordering FROM current_state_resets"
-                    " WHERE ? < event_stream_ordering"
-                    " AND event_stream_ordering <= ?"
-                    " ORDER BY event_stream_ordering ASC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                state_resets = txn.fetchall()
-
                 sql = (
                     "SELECT event_stream_ordering, event_id, state_group"
                     " FROM ex_outlier_stream"
@@ -1630,7 +1613,6 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = txn.fetchall()
             else:
                 new_forward_events = []
-                state_resets = []
                 forward_ex_outliers = []

             sql = (
@@ -1670,7 +1652,6 @@ class EventsStore(SQLBaseStore):
             return AllNewEventsResult(
                 new_forward_events, new_backfill_events,
                 forward_ex_outliers, backward_ex_outliers,
-                state_resets,
             )

         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
@@ -1896,5 +1877,4 @@ class EventsStore(SQLBaseStore):
 AllNewEventsResult = namedtuple("AllNewEventsResult", [
     "new_forward_events", "new_backfill_events",
     "forward_ex_outliers", "backward_ex_outliers",
-    "state_resets"
 ])
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0fdcf29085..845def8467 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
         )

         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
-            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
                 event.state_key, event.internal_metadata.stream_ordering
--
cgit 1.4.1


From 692daf6f5439c3c4852934f3bc950ccac2ec6d92 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 16:10:16 +0000
Subject: Remove membership tests for replication

This is because it now relies on the caches stream, which only works on
postgres. We are trying to test with sqlite.

---
 tests/replication/slave/storage/test_events.py | 43 --------------------------
 1 file changed, 43 deletions(-)

diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 6acb8ab758..105e1228bb 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -58,49 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
     def tearDown(self):
         [unpatch() for unpatch in self.unpatches]

-    @defer.inlineCallbacks
-    def test_room_members(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.replicate()
-        yield self.check("get_rooms_for_user", (USER_ID,), [])
-        yield self.check("get_users_in_room", (ROOM_ID,), [])
-
-        # Join the room.
-        join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
-        yield self.replicate()
-        yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
-            room_id=ROOM_ID,
-            sender=USER_ID,
-            membership="join",
-            event_id=join.event_id,
-            stream_ordering=join.internal_metadata.stream_ordering,
-        )])
-        yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
-
-        # Leave the room.
-        yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
-        yield self.replicate()
-        yield self.check("get_rooms_for_user", (USER_ID,), [])
-        yield self.check("get_users_in_room", (ROOM_ID,), [])
-
-        # Add some other user to the room.
- join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join") - yield self.replicate() - yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser( - room_id=ROOM_ID, - sender=USER_ID, - membership="join", - event_id=join.event_id, - stream_ordering=join.internal_metadata.stream_ordering, - )]) - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) - - yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - ) - yield self.replicate() - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID]) - @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.4.1 From 97479d0c5442f3a644b356c5dbc920bf2ca2c925 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 10:30:03 +0000 Subject: Implement /keys/changes --- synapse/handlers/device.py | 16 +++++++++++++++ synapse/rest/client/v2_alpha/keys.py | 38 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7245d14fab..4a28d95967 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -220,6 +220,22 @@ class DeviceHandler(BaseHandler): for host in hosts: self.federation_sender.send_device_messages(host) + @defer.inlineCallbacks + def get_user_ids_changed(self, user_id, from_device_key): + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(r.room_id for r in rooms) + + user_ids_changed = set() + changed = yield self.store.get_user_whose_devices_changed( + from_device_key + ) + for other_user_id in changed: + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + + defer.returnValue(user_ids_changed) + @defer.inlineCallbacks def _incoming_device_list_update(self, origin, edu_content): user_id = edu_content["user_id"] diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 46789775b9..5080101f18 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -21,6 +21,8 @@ from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_integer ) +from synapse.http.servlet import parse_string +from synapse.types import StreamToken from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -149,6 +151,41 @@ class KeyQueryServlet(RestServlet): defer.returnValue((200, result)) +class KeyChangesServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/keys/changes$", + releases=() + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + super(KeyChangesServlet, self).__init__() + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + from_token_string = parse_string(request, "from") + parse_string(request, "to") # We want to enforce they do pass us one. 
+ + from_token = StreamToken.from_string(from_token_string) + + user_id = requester.user.to_string() + + changed = yield self.device_handler.get_user_ids_changed( + user_id, from_token.device_list_key, + ) + + defer.returnValue((200, { + "changed": changed + })) + + class OneTimeKeyServlet(RestServlet): """ POST /keys/claim HTTP/1.1 @@ -192,4 +229,5 @@ class OneTimeKeyServlet(RestServlet): def register_servlets(hs, http_server): KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) + KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) -- cgit 1.4.1 From acb501c46d75247329f49a1eef3baf6d8af0cba1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 10:32:49 +0000 Subject: Comment --- synapse/rest/client/v2_alpha/keys.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5080101f18..2e855e5e04 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -152,6 +152,14 @@ class KeyQueryServlet(RestServlet): class KeyChangesServlet(RestServlet): + """Returns the list of changes of keys between two stream tokens (may return + spurious results). + + GET /keys/changes?from=...&to=... + + 200 OK + { "changed": ["@foo:example.com"] } + """ PATTERNS = client_v2_patterns( "/keys/changes$", releases=() @@ -171,7 +179,10 @@ class KeyChangesServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from") - parse_string(request, "to") # We want to enforce they do pass us one. + + # We want to enforce they do pass us one, but we ignore it and return + # changes after the "to" as well as before. 
+ parse_string(request, "to") from_token = StreamToken.from_string(from_token_string) -- cgit 1.4.1 From 5deaf9e30bcf37b765e80f08c242e74ca8ac93b3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 10:39:41 +0000 Subject: Up get_latest_event_ids_in_room cache --- synapse/storage/event_federation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index f0aa2193fb..ee88c61954 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @cached() + @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): return self._simple_select_onecol( table="event_forward_extremities", -- cgit 1.4.1 From 368c88c4870f797cee7775acaa2caec2753b7f91 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 10:42:37 +0000 Subject: Add a small cache get_all_new_events --- synapse/storage/events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6685b9da1c..332da25783 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -28,6 +28,7 @@ from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.state import resolve_events +from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -1579,6 +1580,7 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): """Get all the new events that have arrived at the server either as -- cgit 1.4.1 From f6124311fd8893a306e7443cd725b1c25b007d39 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 11:59:17 +0000 Subject: Add m.room.member type to query --- synapse/storage/roommember.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 10f7c7a4bc..3a99dc2349 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -266,7 +266,7 @@ class RoomMemberStore(SQLBaseStore): " ON m.event_id = c.event_id " " AND m.room_id = c.room_id " " AND m.user_id = c.state_key" - " WHERE %(where)s" + " WHERE c.type = 'm.room.member' AND %(where)s" ) % { "where": where_clause, } -- cgit 1.4.1 From 73d676dc8b38e8b16d35b9557480117a6c072ef7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Feb 2017 13:17:17 +0000 Subject: Comment --- synapse/rest/client/v2_alpha/keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 2e855e5e04..4590efa6bf 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -153,7 +153,7 @@ class KeyQueryServlet(RestServlet): class KeyChangesServlet(RestServlet): """Returns the list of changes of keys between two stream tokens (may return - spurious results). + spurious extra results, since we currently ignore the `to` param). GET /keys/changes?from=...&to=... 
--
cgit 1.4.1


From 6d6591880e48761abf26f772cb22a6e7bd0aa71d Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 1 Feb 2017 15:15:16 +0000
Subject: Wake sync up for device changes

---
 synapse/handlers/sync.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9199f20817..9a44de3f33 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -130,7 +130,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.invited or
             self.archived or
             self.account_data or
-            self.to_device
+            self.to_device or
+            self.device_lists
         )
--
cgit 1.4.1


From df4ecff5a9f52e26e01ce364e964fc141c920e52 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 1 Feb 2017 15:42:19 +0000
Subject: Correctly raise exceptions for ratelimiting. Ratelimit on 401

---
 synapse/federation/transaction_queue.py | 2 +-
 synapse/util/retryutils.py              | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb106c6a1b..bb3d9258a6 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -504,7 +504,7 @@ class TransactionQueue(object):
                     code = e.code
                     response = e.response

-                    if e.code == 429 or 500 <= e.code:
+                    if e.code in (401, 404, 429) or 500 <= e.code:
                         logger.info(
                             "TX [%s] {%s} got %d response",
                             destination, txn_id, code
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index b94ae369cf..153ef001ad 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -129,11 +129,13 @@ class RetryDestinationLimiter(object):
             # APIs may expect to never receive e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
             # has been decommissioned.
+            # If we get a 401, then we should probably back off since they
+            # won't accept our requests for at least a while.
+            # 429 is us being aggressively rate limited, so let's rate limit
+            # ourselves.
             if exc_val.code == 404 and self.backoff_on_404:
                 valid_err_code = False
-            elif exc_val.code == 429:
-                # 429 is us being aggressively rate limited, so let's rate limit
-                # ourselves.
+            elif exc_val.code in (401, 429):
                 valid_err_code = False
             elif exc_val.code < 500:
                 valid_err_code = True
--
cgit 1.4.1


From 7e919bdbd09bf200d2e27767450eacbfbf2f4c3f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 1 Feb 2017 17:33:16 +0000
Subject: Include newly joined users in /keys/changes API

---
 synapse/handlers/device.py           | 39 ++++++++++++++++++++++++++++++----
 synapse/rest/client/v2_alpha/keys.py |  2 +-
 synapse/storage/stream.py            |  7 +++++++
 3 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4a28d95967..815410969c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
 # limitations under the License.
from synapse.api import errors
+from synapse.api.constants import EventTypes
 from synapse.util import stringutils
 from synapse.util.async import Linearizer
 from synapse.types import get_domain_from_id
@@ -221,15 +222,45 @@ class DeviceHandler(BaseHandler):
             self.federation_sender.send_device_messages(host)

     @defer.inlineCallbacks
-    def get_user_ids_changed(self, user_id, from_device_key):
+    def get_user_ids_changed(self, user_id, from_token):
         rooms = yield self.store.get_rooms_for_user(user_id)
         room_ids = set(r.room_id for r in rooms)

-        user_ids_changed = set()
+        # First we check if any devices have changed
         changed = yield self.store.get_user_whose_devices_changed(
-            from_device_key
+            from_token.device_list_key
         )
-        for other_user_id in changed:
+
+        # Then work out if any users have since joined
+        rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+
+        possibly_changed = set(changed)
+        for room_id in rooms_changed:
+            # Fetch an approximation of the current state at the time.
+            event_rows, token = yield self.store.get_recent_event_ids_for_room(
+                room_id, end_token=from_token.room_key, limit=1,
+            )
+
+            if event_rows:
+                last_event_id = event_rows[-1]["event_id"]
+                prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id)
+            else:
+                prev_state_ids = {}
+
+            current_state_ids = yield self.state.get_current_state_ids(room_id)
+
+            # If there has been any change in membership, include those users in
+            # the possibly changed list. We'll check if they are joined below,
+            # and we're not too worried about spuriously adding users.
+            for key, event_id in current_state_ids.iteritems():
+                etype, state_key = key
+                if etype == EventTypes.Member:
+                    prev_event_id = prev_state_ids.get(key, None)
+                    if not prev_event_id or prev_event_id != event_id:
+                        possibly_changed.add(state_key)
+
+        user_ids_changed = set()
+        for other_user_id in possibly_changed:
             other_rooms = yield self.store.get_rooms_for_user(other_user_id)
             if room_ids.intersection(e.room_id for e in other_rooms):
                 user_ids_changed.add(other_user_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 4590efa6bf..f99b53530a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -189,7 +189,7 @@ class KeyChangesServlet(RestServlet):
         user_id = requester.user.to_string()

         changed = yield self.device_handler.get_user_ids_changed(
-            user_id, from_token.device_list_key,
+            user_id, from_token,
         )

         defer.returnValue((200, {
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2dc24951c4..cdc1838895 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -244,6 +244,13 @@ class StreamStore(SQLBaseStore):

         defer.returnValue(results)

+    def get_rooms_that_changed(self, room_ids, from_key):
+        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        return set(
+            room_id for room_id in room_ids
+            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+        )
+
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key,
                                         limit=0, order='DESC'):
--
cgit 1.4.1


From d61a04583eda3ba4deea4b82b93e61903919e1a8 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 1 Feb 2017 17:35:23 +0000
Subject: Comment

---
 synapse/storage/stream.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index cdc1838895..3765d0095c 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -245,6 +245,9 @@ class
StreamStore(SQLBaseStore):
         defer.returnValue(results)

     def get_rooms_that_changed(self, room_ids, from_key):
+        """Given a list of rooms and a token, return rooms where there may have
+        been changes.
+        """
         from_key = RoomStreamToken.parse_stream_token(from_key).stream
         return set(
             room_id for room_id in room_ids
--
cgit 1.4.1


From fbfe44bb4de0e490b5c34ebdb1e8c0c09dd766b1 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 1 Feb 2017 17:52:57 +0000
Subject: Doc args

---
 synapse/handlers/device.py | 7 +++++++
 synapse/storage/stream.py  | 4 ++++
 2 files changed, 11 insertions(+)

diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4589dab409..815410969c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -223,6 +223,13 @@ class DeviceHandler(BaseHandler):

     @defer.inlineCallbacks
     def get_user_ids_changed(self, user_id, from_token):
+        """Get list of users that have had their devices updated, or have newly
+        joined a room, that `user_id` may be interested in.
+
+        Args:
+            user_id (str)
+            from_token (StreamToken)
+        """
         rooms = yield self.store.get_rooms_for_user(user_id)
         room_ids = set(r.room_id for r in rooms)

diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 3765d0095c..200d124632 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -247,6 +247,10 @@ class StreamStore(SQLBaseStore):
     def get_rooms_that_changed(self, room_ids, from_key):
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
+
+        Args:
+            room_ids (list)
+            from_key (str): The room_key portion of a StreamToken
         """
         from_key = RoomStreamToken.parse_stream_token(from_key).stream
         return set(
--
cgit 1.4.1


From 10e0737569a87e4f42f2ff64021564fa51539d89 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 1 Feb 2017 16:31:11 +0000
Subject: Bump version and changelog

---
 CHANGES.rst         | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
 synapse/__init__.py |  2 +-
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index 9106134b46..b927a2a285 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,51 @@
+Changes in synapse v0.19.0-rc1 (2017-02-02)
+===========================================
+
+Features:
+
+* Add support for specifying multiple bind addresses (PR #1709, #1712, #1795,
+  #1835)
+* Add /account/3pid/delete endpoint (PR #1714)
+* Add config option to configure the Riot URL used in notification emails (PR
+  #1811)
+* Add username and password config options for turn server (PR #1832)
+* Implement device list updates over federation (PR #1857, #1861, #1864)
+* Implement /keys/changes (PR #1869, #1872)
+
+
+Changes:
+
+* IPv6 support (PR #1696)
+* Log which files we saved attachments to in the media_repository (PR #1791)
+* Linearize updates to membership via PUT /state/ to better handle multiple
+  joins (PR #1787)
+* Limit number of entries to prefill from cache on startup (PR #1792)
+* Remove full_twisted_stacktraces option (PR #1802)
+* Measure size of some caches by sum of the size of cached values (PR #1815)
+* Measure metrics of string_cache (PR #1821)
+* Reduce logging verbosity (PR #1822, #1823, #1824)
+* Don't clobber a displayname or avatar_url if provided by an m.room.member
+  event (PR #1852)
+* Better handle 401/404 response for federation /send/ (PR #1866, #1871)
+
+
+Fixes:
+
+* Fix ability to change password to a non-ascii one (PR #1711)
+* Fix push getting stuck due to looking at the wrong view of state (PR #1820)
+* Fix email address comparison to be case
insensitive (PR #1827) +* Fix occasional inconsistencies of room membership (PR #1836, #1840) + + +Performance: + +* Don't block messages sending on bumping presence (PR #1789) +* Change device_inbox stream index to include user (PR #1793) +* Optimise state resolution (PR #1818) +* Use DB cache of joined users for presence (PR #1862) +* Add an index to make membership queries faster (PR #1867) + + Changes in synapse v0.18.7 (2017-01-09) ======================================= diff --git a/synapse/__init__.py b/synapse/__init__.py index 498ded38c0..a053a02adb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.18.7" +__version__ = "0.19.0-rc1" -- cgit 1.4.1 From 51adaac953c00ee59101a71de6162cde4a0e0a86 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 10:53:36 +0000 Subject: Fix email push in pusher worker This was broken when device list updates were implemented, as Mailer could no longer instantiate an AuthHandler due to a dependency on federation sending. --- synapse/handlers/auth.py | 80 ++++++++++++++++++-------------- synapse/handlers/register.py | 10 ++-- synapse/push/mailer.py | 4 +- synapse/rest/client/v1/login.py | 5 +- synapse/rest/client/v2_alpha/register.py | 3 +- synapse/server.py | 6 ++- tests/handlers/test_auth.py | 12 ++--- tests/handlers/test_register.py | 7 +-- 8 files changed, 70 insertions(+), 57 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 221d7ea7a2..fffba34383 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -65,6 +65,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -529,37 +530,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): - access_token = self.generate_access_token(user_id) + access_token = self.macaroon_gen.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) - def generate_access_token(self, user_id, extra_caveats=None): - extra_caveats = extra_caveats or [] - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = access") - # Include a nonce, to make sure that each login gets a different - # access token. 
- macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - for caveat in extra_caveats: - macaroon.add_first_party_caveat(caveat) - return macaroon.serialize() - - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = login") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - return macaroon.serialize() - - def generate_delete_pusher_token(self, user_id): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = delete_pusher") - return macaroon.serialize() - def validate_short_term_login_token_and_get_user_id(self, login_token): auth_api = self.hs.get_auth() try: @@ -570,15 +545,6 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - def _generate_base_macaroon(self, user_id): - macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, - identifier="key", - key=self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - return macaroon - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) @@ -673,6 +639,48 @@ class AuthHandler(BaseHandler): return False +class MacaroonGeneartor(object): + def __init__(self, hs): + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + self.macaroon_secret_key = hs.config.macaroon_secret_key + + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = access") + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) + return macaroon.serialize() + + def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = login") + now = self.clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location=self.server_name, + identifier="key", + key=self.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + + class _AccountHandler(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 286f0cef0a..03c6a85fc6 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id = None + self.macaroon_gen = hs.get_macaroon_generator() + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): @@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ce2d31fb98..62d794f22b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -81,7 +81,7 @@ class Mailer(object): def __init__(self, hs, app_name): self.hs = hs self.store = self.hs.get_datastore() - self.auth_handler = self.hs.get_auth_handler() + self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name @@ -466,7 +466,7 @@ class Mailer(object): def make_unsubscribe_link(self, user_id, app_id, email_address): params = { - "access_token": self.auth_handler.generate_delete_pusher_token(user_id), + "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, "pushkey": email_address, } diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0c9cdff3b8..72057f1b0c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_GET(self, request): @@ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet): yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(registered_user_id) + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id + ) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3e7a285e10..ccca5a12d5 100644 --- 
a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet): self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_POST(self, request): @@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name ) - access_token = self.auth_handler.generate_access_token( + access_token = self.macaroon_gen.generate_access_token( user_id, ["guest = true"] ) defer.returnValue((200, { diff --git a/synapse/server.py b/synapse/server.py index 0bfb411269..c577032041 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler +from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler @@ -131,6 +131,7 @@ class HomeServer(object): 'federation_transport_client', 'federation_sender', 'receipts_handler', + 'macaroon_generator', ] def __init__(self, hostname, **kwargs): @@ -213,6 +214,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_macaroon_generator(self): + return MacaroonGeneartor(self) + def build_device_handler(self): return DeviceHandler(self) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 9d013e5ca7..1822dcf1e0 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver(handlers=None) self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler + self.macaroon_generator = self.hs.get_macaroon_generator() def test_token_is_a_macaroon(self): - self.hs.config.macaroon_secret_key = "this key is a huge secret" - - token = self.auth_handler.generate_access_token("some_user") + token = self.macaroon_generator.generate_access_token("some_user") # Check that we can parse the thing with pymacaroons macaroon = pymacaroons.Macaroon.deserialize(token) # The most basic of sanity checks @@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.clock.now = 5000 - token = self.auth_handler.generate_access_token("a_user") + token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): @@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase): def test_short_term_login_token_gives_user_id(self): self.hs.clock.now = 1000 - token = self.auth_handler.generate_short_term_login_token( + token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) @@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase): ) def test_short_term_login_token_cannot_replace_user_id(self): - token = self.auth_handler.generate_short_term_login_token( + token = 
self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a4380c48b4..c8cf9a63ec 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase): handlers=None, http_client=None, expire_access_token=True) - self.auth_handler = Mock( + self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret')) + self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() - self.mock_handler = Mock(spec=[ - "generate_access_token", - ]) - self.hs.get_auth_handler = Mock(return_value=self.auth_handler) @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): -- cgit 1.4.1 From 85e98fd4e83b8e34e867999adbe8b772c3ca9ff6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 11:02:32 +0000 Subject: Update changelog --- CHANGES.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index b927a2a285..979be4ae68 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -15,7 +15,7 @@ Features: Changes: -* IPv6 support (PR #1696)) +* Improve IPv6 support (PR #1696) * Log which files we saved attachments to in the media_repository (PR #1791) * Linearize updates to membership via PUT /state/ to better handle multiple joins (PR #1787) -- cgit 1.4.1 From bfe3f5815faa230692ec9090baaea17224bd8a4c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 11:10:41 +0000 Subject: Update changelog --- CHANGES.rst | 6 ++++++ synapse/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 979be4ae68..8473d7d48c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,9 @@ +Changes in synapse v0.19.0-rc2 (2017-02-02) +=========================================== + +* Include newly joined users in /keys/changes API (PR #1872) + + Changes in synapse v0.19.0-rc1 (2017-02-02) =========================================== diff --git a/synapse/__init__.py b/synapse/__init__.py index a053a02adb..d398799579 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.19.0-rc1" +__version__ = "0.19.0-rc2" -- cgit 1.4.1 From 54a79c1d374f09049d3eb8ac531bca45d68b5f2b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 13:07:18 +0000 Subject: Make presence.get_new_events a bit faster We do this by caching the set of users a user shares rooms with. --- synapse/handlers/presence.py | 44 +++++++++++++++---------------------------- synapse/handlers/room.py | 1 + synapse/notifier.py | 1 + synapse/storage/roommember.py | 16 ++++++++++++++++ 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9982ae0fed..fdfce2a88c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1011,7 +1011,7 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function def get_new_events(self, user, from_key, room_ids=None, include_offline=True, - **kwargs): + explicit_room_id=None, **kwargs): # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. 
Get the list of user in the rooms. @@ -1028,22 +1028,24 @@ class PresenceEventSource(object): user_id = user.to_string() if from_key is not None: from_key = int(from_key) - room_ids = room_ids or [] presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - if not room_ids: - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = set(e.room_id for e in rooms) - else: - room_ids = set(room_ids) - max_token = self.store.get_current_presence_token() plist = yield self.store.get_presence_list_accepted(user.localpart) - friends = set(row["observed_user_id"] for row in plist) - friends.add(user_id) # So that we receive our own presence + users_interested_in = set(row["observed_user_id"] for row in plist) + users_interested_in.add(user_id) # So that we receive our own presence + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + users_interested_in.update(users_who_share_room) + + if explicit_room_id: + user_ids = yield self.store.get_users_in_room(explicit_room_id) + users_interested_in.update(user_ids) user_ids_changed = set() changed = None @@ -1055,35 +1057,19 @@ class PresenceEventSource(object): # work out if we share a room or they're in our presence list get_updates_counter.inc("stream") for other_user_id in changed: - if other_user_id in friends: + if other_user_id in users_interested_in: user_ids_changed.add(other_user_id) - continue - other_rooms = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(e.room_id for e in other_rooms): - user_ids_changed.add(other_user_id) - continue else: # Too many possible updates. Find all users we can see and check # if any of them have changed. get_updates_counter.inc("full") - user_ids_to_check = set() - for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) - user_ids_to_check.update(users) - - user_ids_to_check.update(friends) - - # Always include yourself. Only really matters for when the user is - # not in any rooms, but still. - user_ids_to_check.add(user_id) - if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - user_ids_to_check, from_key, + users_interested_in, from_key, ) else: - user_ids_changed = user_ids_to_check + user_ids_changed = users_interested_in updates = yield presence.current_state_for_users(user_ids_changed) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5f18007e90..7e7671c9a2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -437,6 +437,7 @@ class RoomEventSource(object): limit, room_ids, is_guest, + explicit_room_id=None, ): # We just ignore the key for now. 
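The presence.py hunk above reads more clearly as plain functions than as a diff: instead of re-fetching each changed user's room list and intersecting it with our own (one store call per changed user), the handler now precomputes a single flat set of user ids it is interested in and does a set-membership test per update. A minimal sketch, with plain sets standing in for the store and its deferreds::

    def interesting_users(user_id, presence_list_user_ids, users_who_share_room):
        # Users on our accepted presence list, ourselves (so we receive our
        # own presence), and everyone we currently share a room with.
        interested = set(presence_list_user_ids)
        interested.add(user_id)
        interested.update(users_who_share_room)
        return interested

    def filter_changed(changed_user_ids, interested):
        # Previously: a get_rooms_for_user() call per changed user, then an
        # intersection against our room ids. Now: one O(1) lookup each.
        return set(u for u in changed_user_ids if u in interested)
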
diff --git a/synapse/notifier.py b/synapse/notifier.py index acbd4bb5ae..8051a7a842 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -378,6 +378,7 @@ class Notifier(object): limit=limit, is_guest=is_peeking, room_ids=room_ids, + explicit_room_id=explicit_room_id, ) if name == "room": diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ee800d074f..70718f41ed 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -280,6 +280,22 @@ class RoomMemberStore(SQLBaseStore): user_id, membership_list=[Membership.JOIN], ) + @cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + rooms = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate, + ) + + user_who_share_room = set() + for room in rooms: + user_ids = yield self.get_users_in_room( + room.room_id, on_invalidate=cache_context.invalidate, + ) + logger.info("Users in room: %r %r", room.room_id, user_ids) + user_who_share_room.update(user_ids) + + defer.returnValue(user_who_share_room) + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): -- cgit 1.4.1 From 832e9c52ca9dd349e28967727ae87a484e6ce557 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 13:09:56 +0000 Subject: Comment --- synapse/storage/roommember.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 70718f41ed..249217e114 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -282,6 +282,8 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True) def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ rooms = yield self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate, ) @@ -291,7 +293,6 @@ class RoomMemberStore(SQLBaseStore): user_ids = yield self.get_users_in_room( room.room_id, on_invalidate=cache_context.invalidate, ) - logger.info("Users in room: %r %r", room.room_id, user_ids) user_who_share_room.update(user_ids) defer.returnValue(user_who_share_room) -- cgit 1.4.1 From 9efcc3f3be1f56be8ffdb3172d6908d55028cc61 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 13:50:22 +0000 Subject: Comment --- synapse/util/caches/descriptors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 675bfd5feb..3c6838df16 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -478,6 +478,8 @@ class CacheListDescriptor(object): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): + # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, + # which namedtuple does for us. 
def invalidate(self): self.cache.invalidate(self.key) -- cgit 1.4.1 From 46ecd9fd6d4ea2786cde0f2576aa28421be40047 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 14:27:19 +0000 Subject: Use stream_ordering_to_exterm for /keys/changes --- synapse/handlers/device.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 815410969c..6c1b945ff3 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -17,7 +17,7 @@ from synapse.api import errors from synapse.api.constants import EventTypes from synapse.util import stringutils from synapse.util.async import Linearizer -from synapse.types import get_domain_from_id +from synapse.types import get_domain_from_id, RoomStreamToken from twisted.internet import defer from ._base import BaseHandler @@ -243,15 +243,15 @@ class DeviceHandler(BaseHandler): possibly_changed = set(changed) for room_id in rooms_changed: - # Fetch (an approximation) of the current state at the time. - event_rows, token = yield self.store.get_recent_event_ids_for_room( - room_id, end_token=from_token.room_key, limit=1, - ) + # Fetch the current state at the time. + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key) - if event_rows: - last_event_id = event_rows[-1]["event_id"] - prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id) - else: + try: + event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_ordering=stream_ordering + ) + prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + except: prev_state_ids = {} current_state_ids = yield self.state.get_current_state_ids(room_id) -- cgit 1.4.1 From 6b61060b51970bd170fffd442df0e9f02ddcf678 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 14:47:15 +0000 Subject: Comment --- synapse/util/caches/descriptors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 3c6838df16..998de70d29 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -479,7 +479,10 @@ class CacheListDescriptor(object): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us. + # which namedtuple does for us (i.e. two _CacheContext are the same if + # their caches and keys match). This is important in particular to + # dedupe when we add callbacks to lru cache nodes, otherwise the number + # of callbacks would grow. 
def invalidate(self): self.cache.invalidate(self.key) -- cgit 1.4.1 From 6826593b8168d648b74a4d1c45ebe5aa66588d8e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 14:55:54 +0000 Subject: sets aren't JSON serializable --- synapse/rest/client/v2_alpha/keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f99b53530a..6a3cfe84f8 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -193,7 +193,7 @@ class KeyChangesServlet(RestServlet): ) defer.returnValue((200, { - "changed": changed + "changed": list(changed), })) -- cgit 1.4.1 From 0f3e296cb7d63aa8b4cfcaa54a0b2a63fbe7c943 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 15:02:03 +0000 Subject: Fix replication --- synapse/replication/slave/storage/events.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 15a025a019..d72ff6055c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -73,6 +73,9 @@ class SlavedEventStore(BaseSlavedStore): # to reach inside the __dict__ to extract them. get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_users_who_share_room_with_user = ( + RoomMemberStore.__dict__["get_users_who_share_room_with_user"] + ) get_latest_event_ids_in_room = EventFederationStore.__dict__[ "get_latest_event_ids_in_room" ] -- cgit 1.4.1 From 1232ae41cf1ee4e66025b0db3460e339ef3b9971 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 15:25:00 +0000 Subject: Use new get_users_who_share_room_with_user --- synapse/handlers/device.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6c1b945ff3..158206aef6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -198,20 +198,22 @@ class DeviceHandler(BaseHandler): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. 
""" - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = [r.room_id for r in rooms] + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) hosts = set() if self.hs.is_mine_id(user_id): - for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) - hosts.update(get_domain_from_id(u) for u in users) + hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) position = yield self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, ) @@ -266,13 +268,13 @@ class DeviceHandler(BaseHandler): if not prev_event_id or prev_event_id != event_id: possibly_changed.add(state_key) - user_ids_changed = set() - for other_user_id in possibly_changed: - other_rooms = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(e.room_id for e in other_rooms): - user_ids_changed.add(other_user_id) + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) - defer.returnValue(user_ids_changed) + # Take the intersection of the users whose devices may have changed + # and those that actually still share a room with the user + defer.returnValue(users_who_share_room & possibly_changed) @defer.inlineCallbacks def _incoming_device_list_update(self, origin, edu_content): -- cgit 1.4.1 From 82b3e0851c1bd8569a37231d9dac7430294b8096 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 17:15:17 +0000 Subject: Bump version and changelog --- CHANGES.rst | 17 +++++++++++++---- synapse/__init__.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 8473d7d48c..a2b1023e78 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,11 @@ +Changes in synapse v0.19.0-rc3 (2017-02-03) +=========================================== + +* Fix email push in pusher worker (PR #1875) +* Make presence.get_new_events a bit faster (PR #1876) +* Make /keys/changes a bit more performant (PR #1877) + + Changes in synapse v0.19.0-rc2 (2017-02-02) =========================================== @@ -10,18 +18,19 @@ Changes in synapse v0.19.0-rc1 (2017-02-02) Features: * Add support for specifying multiple bind addresses (PR #1709, #1712, #1795, - #1835) + #1835). Thanks to @kyrias! * Add /account/3pid/delete endpoint (PR #1714) * Add config option to configure the Riot URL used in notification emails (PR - #1811) -* Add username and password config options for turn server (PR #1832) + #1811). Thanks to @aperezdc! +* Add username and password config options for turn server (PR #1832). Thanks + to @xsteadfastx! * Implement device lists updates over federation (PR #1857, #1861, #1864) * Implement /keys/changes (PR #1869, #1872) Changes: -* Improve IPv6 support (PR #1696) +* Improve IPv6 support (PR #1696). Thanks to @kyrias and @glyph! * Log which files we saved attachments to in the media_repository (PR #1791) * Linearize updates to membership via PUT /state/ to better handle multiple joins (PR #1787) diff --git a/synapse/__init__.py b/synapse/__init__.py index d398799579..ceabdcb489 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. 
""" -__version__ = "0.19.0-rc2" +__version__ = "0.19.0-rc3" -- cgit 1.4.1 From a597994fb6acff8d2f72855445c322d7fd685f3f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 18:36:17 +0000 Subject: Measure new device list stuff --- synapse/handlers/device.py | 4 ++++ synapse/handlers/sync.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 158206aef6..8cb47ac417 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -17,6 +17,7 @@ from synapse.api import errors from synapse.api.constants import EventTypes from synapse.util import stringutils from synapse.util.async import Linearizer +from synapse.util.metrics import measure_func from synapse.types import get_domain_from_id, RoomStreamToken from twisted.internet import defer from ._base import BaseHandler @@ -193,6 +194,7 @@ class DeviceHandler(BaseHandler): else: raise + @measure_func("notify_device_update") @defer.inlineCallbacks def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and @@ -223,6 +225,7 @@ class DeviceHandler(BaseHandler): for host in hosts: self.federation_sender.send_device_messages(host) + @measure_func("device.get_user_ids_changed") @defer.inlineCallbacks def get_user_ids_changed(self, user_id, from_token): """Get list of users that have had the devices updated, or have newly @@ -276,6 +279,7 @@ class DeviceHandler(BaseHandler): # and those that actually still share a room with the user defer.returnValue(users_who_share_room & possibly_changed) + @measure_func("_incoming_device_list_update") @defer.inlineCallbacks def _incoming_device_list_update(self, origin, edu_content): user_id = edu_content["user_id"] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9a44de3f33..d7dcd1ce5b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,7 +16,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure +from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user from synapse.visibility import filter_events_for_client @@ -561,6 +561,7 @@ class SyncHandler(object): next_batch=sync_result_builder.now_token, )) + @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks def _generate_sync_entry_for_device_list(self, sync_result_builder): user_id = sync_result_builder.sync_config.user.to_string() -- cgit 1.4.1 From 38258a097657b34d4f521712ef869575a486610e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 18:45:55 +0000 Subject: Bump cache sizes for common membership queries --- synapse/storage/roommember.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 249217e114..545d3d3a99 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -129,7 +129,7 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=100000, iterable=True) + @cached(max_entries=500000, iterable=True) def get_users_in_room(self, room_id): def f(txn): @@ -274,13 +274,13 @@ class 
RoomMemberStore(SQLBaseStore): return rows - @cached(max_entries=5000) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user(self, user_id): return self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) - @cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True) + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) def get_users_who_share_room_with_user(self, user_id, cache_context): """Returns the set of users who share a room with `user_id` """ -- cgit 1.4.1 From 84f600b2eec3d736ed4e667090dddf100965a77d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 18:58:33 +0000 Subject: Bump changelog and version --- CHANGES.rst | 8 +++++++- synapse/__init__.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a2b1023e78..d15973a8c6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,4 +1,10 @@ -Changes in synapse v0.19.0-rc3 (2017-02-03) +Changes in synapse v0.19.0-rc4 (2017-02-02) +=========================================== + +* Bump cache sizes for common membership queries (PR #1879) + + +Changes in synapse v0.19.0-rc3 (2017-02-02) =========================================== * Fix email push in pusher worker (PR #1875) diff --git a/synapse/__init__.py b/synapse/__init__.py index ceabdcb489..950c4178b8 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.19.0-rc3" +__version__ = "0.19.0-rc4" -- cgit 1.4.1 From 38434a7fbbd614b8e37f3410db109893605cbee2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 4 Feb 2017 08:27:51 +0000 Subject: Bump changelog and version --- CHANGES.rst | 6 ++++++ synapse/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index d15973a8c6..da241666d6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,9 @@ +Changes in synapse v0.19.0 (2017-02-04) +======================================= + +No changes since RC 4. + + Changes in synapse v0.19.0-rc4 (2017-02-02) =========================================== diff --git a/synapse/__init__.py b/synapse/__init__.py index 950c4178b8..d3f445be9c 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.19.0-rc4" +__version__ = "0.19.0" -- cgit 1.4.1
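
Several changes in this release lean on the same derived cache: the set of users who share a room with a given user, built from the per-user room list and the per-room member list (both themselves cached). Stripped of the @cachedInlineCallbacks machinery and invalidation callbacks, the computation reduces to the following sketch, with plain dicts standing in for the store::

    def users_who_share_room_with(user_id, rooms_for_user, users_in_room):
        # rooms_for_user: user_id -> iterable of room ids
        # users_in_room:  room_id -> iterable of user ids
        shared = set()
        for room_id in rooms_for_user.get(user_id, ()):
            shared.update(users_in_room.get(room_id, ()))
        return shared

    # e.g. the device-list code earlier in this series intersects this set
    # with the users whose devices may have changed, rather than walking
    # rooms per user (example data, not real ids):
    rooms_for_user = {"@a:hs": ["!r1", "!r2"]}
    users_in_room = {"!r1": ["@a:hs", "@b:hs"], "!r2": ["@a:hs", "@c:hs"]}
    possibly_changed = {"@b:hs", "@d:hs"}
    shared = users_who_share_room_with("@a:hs", rooms_for_user, users_in_room)
    assert (shared & possibly_changed) == {"@b:hs"}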