diff options
Diffstat (limited to 'synapse')
60 files changed, 1291 insertions, 690 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6502a6be7b..34382e4e3c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -26,6 +26,7 @@ import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -775,34 +776,56 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None): + def check_auth_blocking(self, user_id=None, threepid=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag Args: user_id(str|None): If present, checks for presence against existing MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. """ + + # Never fail an auth check for the server notices users + # This can be a problem where event creation is prohibited due to blocking + if user_id == self.hs.config.server_notices_mxid: + return + if self.hs.config.hs_disabled: raise ResourceLimitError( 403, self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEED, - admin_uri=self.hs.config.admin_uri, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=self.hs.config.admin_contact, limit_type=self.hs.config.hs_disabled_limit_type ) if self.hs.config.limit_usage_by_mau is True: - # If the user is already part of the MAU cohort + assert not (user_id and threepid) + + # If the user is already part of the MAU cohort or a trial user if user_id: timestamp = yield self.store.user_last_seen_monthly_active(user_id) if timestamp: return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self.hs.config, threepid): + return # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( 403, "Monthly Active User Limit Exceeded", - - admin_uri=self.hs.config.admin_uri, - errcode=Codes.RESOURCE_LIMIT_EXCEED, + admin_contact=self.hs.config.admin_contact, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, limit_type="monthly_active_user" ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index b0da506f6d..c2630c4c64 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -78,6 +78,7 @@ class EventTypes(object): Name = "m.room.name" ServerACL = "m.room.server_acl" + Pinned = "m.room.pinned_events" class RejectedReason(object): @@ -97,9 +98,17 @@ class ThirdPartyEntityKind(object): LOCATION = "location" +class RoomVersions(object): + V1 = "1" + VDH_TEST = "vdh-test-version" + + # the version we will give rooms which are created on this server -DEFAULT_ROOM_VERSION = "1" +DEFAULT_ROOM_VERSION = RoomVersions.V1 # vdh-test-version is a placeholder to get room versioning support working and tested # until we have a working v2. -KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} +KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST} + +ServerNoticeMsgType = "m.server_notice" +ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e26001ab12..2e7f98404d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -56,7 +56,7 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" - RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED" + RESOURCE_LIMIT_EXCEEDED = "M_RESOURCE_LIMIT_EXCEEDED" UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" @@ -238,11 +238,11 @@ class ResourceLimitError(SynapseError): """ def __init__( self, code, msg, - errcode=Codes.RESOURCE_LIMIT_EXCEED, - admin_uri=None, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=None, limit_type=None, ): - self.admin_uri = admin_uri + self.admin_contact = admin_contact self.limit_type = limit_type super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) @@ -250,7 +250,7 @@ class ResourceLimitError(SynapseError): return cs_error( self.msg, self.errcode, - admin_uri=self.admin_uri, + admin_contact=self.admin_contact, limit_type=self.limit_type ) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 186831e118..a31a9a17e0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -251,6 +251,7 @@ class FilterCollection(object): "include_leave", False ) self.event_fields = filter_json.get("event_fields", []) + self.event_format = filter_json.get("event_format", "client") def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 3348a8ec6d..86b5067400 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -51,10 +51,7 @@ class AppserviceSlaveStore( class AppserviceServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = AppserviceSlaveStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index ab79a45646..ce2b113dbb 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -74,10 +74,7 @@ class ClientReaderSlavedStore( class ClientReaderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = ClientReaderSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 03d39968a8..f98e456ea0 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) from synapse.rest.client.v1.room import ( JoinRoomAliasServlet, RoomMembershipRestServlet, @@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.engines import create_engine +from synapse.storage.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole @@ -62,6 +68,9 @@ logger = logging.getLogger("synapse.app.event_creator") class EventCreatorSlavedStore( + # FIXME(#3714): We need to add UserDirectoryStore as we write directly + # rather than going via the correct worker. + UserDirectoryStore, DirectoryStore, SlavedTransactionStore, SlavedProfileStore, @@ -81,10 +90,7 @@ class EventCreatorSlavedStore( class EventCreatorServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = EventCreatorSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -101,6 +107,9 @@ class EventCreatorServer(HomeServer): RoomMembershipRestServlet(self).register(resource) RoomStateEventRestServlet(self).register(resource) JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) + ProfileDisplaynameRestServlet(self).register(resource) + ProfileRestServlet(self).register(resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7d8105778d..60f5973505 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -72,10 +72,7 @@ class FederationReaderSlavedStore( class FederationReaderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FederationReaderSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index d59007099b..60dd09aac3 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -78,10 +78,7 @@ class FederationSenderSlaveStore( class FederationSenderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FederationSenderSlaveStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 8d484c1cd4..8c0b9c67b0 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -148,10 +148,7 @@ class FrontendProxySlavedStore( class FrontendProxyServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FrontendProxySlavedStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 005921dcf7..3eb5b663de 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,7 +62,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer -from synapse.storage import are_all_users_on_domain +from synapse.storage import DataStore, are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.util.caches import CACHE_SIZE_FACTOR @@ -111,6 +111,8 @@ def build_resource_for_web_client(hs): class SynapseHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + def _listener_http(self, config, listener_config): port = listener_config["port"] bind_addresses = listener_config["bind_addresses"] @@ -356,13 +358,13 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn(run_new_connection=False) - prepare_database(db_conn, database_engine, config=config) - database_engine.on_new_connection(db_conn) + with hs.get_db_conn(run_new_connection=False) as db_conn: + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) - hs.run_startup_checks(db_conn, database_engine) + hs.run_startup_checks(db_conn, database_engine) - db_conn.commit() + db_conn.commit() except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index fd1f6cbf7e..e3dbb3b4e6 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -60,10 +60,7 @@ class MediaRepositorySlavedStore( class MediaRepositoryServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = MediaRepositorySlavedStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index a4fc7e91fa..244c604de9 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -78,10 +78,7 @@ class PusherSlaveStore( class PusherServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = PusherSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = PusherSlaveStore def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 27e1998660..6662340797 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -249,10 +249,7 @@ class SynchrotronApplicationService(object): class SynchrotronServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = SynchrotronSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1388a42b59..96ffcaf073 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -94,10 +94,7 @@ class UserDirectorySlaveStore( class UserDirectoryServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = UserDirectorySlaveStore def _listen_http(self, listener_config): port = listener_config["port"] diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6980e5890e..9ccc5a80fc 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib + +from six.moves import urllib from prometheus_client import Counter @@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient): def query_user(self, service, user_id): if service.url is None: defer.returnValue(False) - uri = service.url + ("/users/%s" % urllib.quote(user_id)) + uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: response = yield self.get_json(uri, { @@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient): def query_alias(self, service, alias): if service.url is None: defer.returnValue(False) - uri = service.url + ("/rooms/%s" % urllib.quote(alias)) + uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: response = yield self.get_json(uri, { @@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient): service.url, APP_SERVICE_PREFIX, kind, - urllib.quote(protocol) + urllib.parse.quote(protocol) ) try: response = yield self.get_json(uri, fields) @@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, - urllib.quote(protocol) + urllib.parse.quote(protocol) ) try: info = yield self.get_json(uri, {}) @@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient): txn_id = str(txn_id) uri = service.url + ("/transactions/%s" % - urllib.quote(txn_id)) + urllib.parse.quote(txn_id)) try: yield self.put_json( uri=uri, diff --git a/synapse/config/server.py b/synapse/config/server.py index 68a612e594..c1c7c0105e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -77,10 +77,15 @@ class ServerConfig(Config): self.max_mau_value = config.get( "max_mau_value", 0, ) + self.mau_limits_reserved_threepids = config.get( "mau_limit_reserved_threepids", [] ) + self.mau_trial_days = config.get( + "mau_trial_days", 0, + ) + # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") @@ -88,7 +93,7 @@ class ServerConfig(Config): # Admin uri to direct users at should their instance become blocked # due to resource constraints - self.admin_uri = config.get("admin_uri", None) + self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None @@ -352,7 +357,7 @@ class ServerConfig(Config): # Homeserver blocking # # How to reach the server admin, used in ResourceLimitError - # admin_uri: 'mailto:admin@server.com' + # admin_contact: 'mailto:admin@server.com' # # Global block config # @@ -365,6 +370,7 @@ class ServerConfig(Config): # Enables monthly active user checking # limit_usage_by_mau: False # max_mau_value: 50 + # mau_trial_days: 2 # # Sometimes the server admin will want to ensure certain accounts are # never blocked by mau checking. These accounts are specified here. @@ -398,6 +404,23 @@ class ServerConfig(Config): " service on the given port.") +def is_threepid_reserved(config, threepid): + """Check the threepid against the reserved threepid config + Args: + config(ServerConfig) - to access server config attributes + threepid(dict) - The threepid to test for + + Returns: + boolean Is the threepid undertest reserved_user + """ + + for tp in config.mau_limits_reserved_threepids: + if (threepid['medium'] == tp['medium'] + and threepid['address'] == tp['address']): + return True + return False + + def read_gc_thresholds(thresholds): """Reads the three integer thresholds for garbage collection. Ensures that the thresholds are integers if thresholds are supplied. diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c20a32096a..e94400b8e2 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -18,7 +18,9 @@ import logging from canonicaljson import json from twisted.internet import defer, reactor +from twisted.internet.error import ConnectError from twisted.internet.protocol import Factory +from twisted.names.error import DomainError from twisted.web.http import HTTPClient from synapse.http.endpoint import matrix_federation_endpoint @@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): server_response, server_certificate = yield protocol.remote_key defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: - logger.exception("Error getting key for %r" % (server_name,)) + logger.warn("Error getting key for %r: %s", server_name, e) if e.status.startswith("4"): # Don't retry for 4xx responses. raise IOError("Cannot get key for %r" % server_name) + except (ConnectError, DomainError) as e: + logger.warn("Error getting key for %r: %s", server_name, e) except Exception as e: - logger.exception(e) + logger.exception("Error getting key for %r", server_name) raise IOError("Cannot get key for %r" % server_name) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 547c6aec80..dbee404ea7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -838,9 +838,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): ) return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, + edu_type=edu_type, + origin=origin, + content=content, ) def on_query(self, query_type, args): @@ -851,6 +851,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): return handler(args) return self._get_query_client( - query_type=query_type, - args=args, + query_type=query_type, + args=args, ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0bb468385d..6f5995735a 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -32,7 +32,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple -from six import iteritems, itervalues +from six import iteritems from sortedcontainers import SortedDict @@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object): user_ids = set( user_id - for uids in itervalues(self.presence_changed) + for uids in self.presence_changed.values() for user_id in uids ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index b4fbe2c9d5..1054441ca5 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -106,7 +106,7 @@ class TransportLayerClient(object): dest (str) room_id (str) event_tuples (list) - limt (int) + limit (int) Returns: Deferred: Results in a dict received from the remote homeserver. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8cde9716ac..3972922ff9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -261,10 +261,10 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.exception("authenticate_request failed") + logger.warn("authenticate_request failed: missing authentication") raise - except Exception: - logger.exception("authenticate_request failed") + except Exception as e: + logger.warn("authenticate_request failed: %s", e) raise if origin: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3dd107a285..3fa7a98445 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -291,8 +291,9 @@ class FederationHandler(BaseHandler): ev_ids, get_prev_content=False, check_redacted=False ) + room_version = yield self.store.get_room_version(pdu.room_id) state_map = yield resolve_events_with_factory( - state_groups, {pdu.event_id: pdu}, fetch + room_version, state_groups, {pdu.event_id: pdu}, fetch ) state = (yield self.store.get_events(state_map.values())).values() @@ -1828,7 +1829,10 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state = self.state_handler.resolve_events( + room_version = yield self.store.get_room_version(event.room_id) + + new_state = yield self.state_handler.resolve_events( + room_version, [list(local_view.values()), list(remote_view.values())], event ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 9af2e8f869..75b8b7ce6a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -32,12 +32,16 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -class ProfileHandler(BaseHandler): - PROFILE_UPDATE_MS = 60 * 1000 - PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 +class BaseProfileHandler(BaseHandler): + """Handles fetching and updating user profile information. + + BaseProfileHandler can be instantiated directly on workers and will + delegate to master when necessary. The master process should use the + subclass MasterProfileHandler + """ def __init__(self, hs): - super(ProfileHandler, self).__init__(hs) + super(BaseProfileHandler, self).__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -46,11 +50,6 @@ class ProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - if hs.config.worker_app is None: - self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, - ) - @defer.inlineCallbacks def get_profile(self, user_id): target_user = UserID.from_string(user_id) @@ -282,6 +281,20 @@ class ProfileHandler(BaseHandler): room_id, str(e.message) ) + +class MasterProfileHandler(BaseProfileHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs): + super(MasterProfileHandler, self).__init__(hs) + + assert hs.config.worker_app is None + + self.clock.looping_call( + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + ) + def _start_update_remote_profile_cache(self): return run_as_background_process( "Update remote profile", self._update_remote_profile_cache, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f03ee1476b..1e53f2c635 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -125,6 +125,7 @@ class RegistrationHandler(BaseHandler): guest_access_token=None, make_guest=False, admin=False, + threepid=None, ): """Registers a new client on the server. @@ -145,7 +146,7 @@ class RegistrationHandler(BaseHandler): RegistrationError if there was a problem registering. """ - yield self.auth.check_auth_blocking() + yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: password_hash = yield self.auth_handler().hash(password) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fb94b5d7d4..f643619047 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -344,6 +344,7 @@ class RoomMemberHandler(object): latest_event_ids = ( event_id for (event_id, _, _) in prev_events_and_hashes ) + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 648debc8aa..ef20c2296c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -745,9 +745,16 @@ class SyncHandler(object): state_ids = {} if lazy_load_members: if types: + # We're returning an incremental sync, with no "gap" since + # the previous sync, so normally there would be no state to return + # But we're lazy-loading, so the client might need some more + # member events to understand the events in this timeline. + # So we fish out all the member events corresponding to the + # timeline here, and then dedupe any redundant ones below. + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id, types=types, - filtered_types=filtered_types, + filtered_types=None, # we only want members! ) if lazy_load_members and not include_redundant_members: diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 37dda64587..d8413d6aa7 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -119,6 +119,8 @@ class UserDirectoryHandler(object): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ + # FIXME(#3714): We should probably do this in the same worker as all + # the other changes. yield self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url, None, ) @@ -127,6 +129,8 @@ class UserDirectoryHandler(object): def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ + # FIXME(#3714): We should probably do this in the same worker as all + # the other changes. yield self.store.remove_from_user_dir(user_id) yield self.store.remove_from_user_in_public_room(user_id) diff --git a/synapse/http/client.py b/synapse/http/client.py index ab4fbf59b2..4ba54fed05 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,24 +13,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging -import urllib -from six import StringIO +from six import text_type +from six.moves import urllib +import treq from canonicaljson import encode_canonical_json, json from prometheus_client import Counter from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, reactor, ssl, task +from twisted.internet import defer, protocol, reactor, ssl from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone from twisted.web.client import ( Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, - FileBodyProducer as TwistedFileBodyProducer, GzipDecoder, HTTPConnectionPool, PartialDownloadError, @@ -83,18 +84,20 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + self.user_agent = self.user_agent.encode('ascii') + @defer.inlineCallbacks - def request(self, method, uri, *args, **kwargs): + def request(self, method, uri, data=b'', headers=None): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) - logger.info("Sending request %s %s", method, redact_uri(uri)) + logger.info("Sending request %s %s", method, redact_uri(uri.encode('ascii'))) try: - request_deferred = self.agent.request( - method, uri, *args, **kwargs + request_deferred = treq.request( + method, uri, agent=self.agent, data=data, headers=headers ) add_timeout_to_deferred( request_deferred, 60, self.hs.get_reactor(), @@ -105,14 +108,14 @@ class SimpleHttpClient(object): incoming_responses_counter.labels(method, response.code).inc() logger.info( "Received response to %s %s: %s", - method, redact_uri(uri), response.code + method, redact_uri(uri.encode('ascii')), response.code ) defer.returnValue(response) except Exception as e: incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, redact_uri(uri), type(e).__name__, e.message + method, redact_uri(uri.encode('ascii')), type(e).__name__, e.args[0] ) raise @@ -137,7 +140,8 @@ class SimpleHttpClient(object): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + query_bytes = urllib.parse.urlencode( + encode_urlencode_args(args), True).encode("utf8") actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -148,15 +152,14 @@ class SimpleHttpClient(object): response = yield self.request( "POST", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(query_bytes)) + data=query_bytes ) - body = yield make_deferred_yieldable(readBody(response)) - if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + body = yield make_deferred_yieldable(treq.json_content(response)) + defer.returnValue(body) else: raise HttpResponseException(response.code, response.phrase, body) @@ -191,9 +194,9 @@ class SimpleHttpClient(object): response = yield self.request( "POST", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(json_str)) + data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -248,7 +251,7 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ if len(args): - query_bytes = urllib.urlencode(args, True) + query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) json_str = encode_canonical_json(json_body) @@ -262,9 +265,9 @@ class SimpleHttpClient(object): response = yield self.request( "PUT", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(json_str)) + data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -293,7 +296,7 @@ class SimpleHttpClient(object): HttpResponseException on a non-2xx HTTP response. """ if len(args): - query_bytes = urllib.urlencode(args, True) + query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) actual_headers = { @@ -304,7 +307,7 @@ class SimpleHttpClient(object): response = yield self.request( "GET", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), ) @@ -339,7 +342,7 @@ class SimpleHttpClient(object): response = yield self.request( "GET", - url.encode("ascii"), + url, headers=Headers(actual_headers), ) @@ -434,12 +437,12 @@ class CaptchaServerHttpClient(SimpleHttpClient): @defer.inlineCallbacks def post_urlencoded_get_raw(self, url, args={}): - query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True) response = yield self.request( "POST", - url.encode("ascii"), - bodyProducer=FileBodyProducer(StringIO(query_bytes)), + url, + data=query_bytes, headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], b"User-Agent": [self.user_agent], @@ -510,7 +513,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): - if isinstance(arg, unicode): + if isinstance(arg, text_type): return arg.encode('utf-8') elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] @@ -542,26 +545,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): def creatorForNetloc(self, hostname, port): return self - - -class FileBodyProducer(TwistedFileBodyProducer): - """Workaround for https://twistedmatrix.com/trac/ticket/8473 - - We override the pauseProducing and resumeProducing methods in twisted's - FileBodyProducer so that they do not raise exceptions if the task has - already completed. - """ - - def pauseProducing(self): - try: - super(FileBodyProducer, self).pauseProducing() - except task.TaskDone: - # task has already completed - pass - - def resumeProducing(self): - try: - super(FileBodyProducer, self).resumeProducing() - except task.NotPaused: - # task was not paused (probably because it had already completed) - pass diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 44b61e70a4..6a1fc8ca55 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,19 +17,19 @@ import cgi import logging import random import sys -import urllib -from six import string_types -from six.moves.urllib import parse as urlparse +from six import PY3, string_types +from six.moves import urllib -from canonicaljson import encode_canonical_json, json +import treq +from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json from twisted.internet import defer, protocol, reactor from twisted.internet.error import DNSLookupError from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.http_headers import Headers import synapse.metrics @@ -58,13 +58,18 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 +if PY3: + MAXINT = sys.maxsize +else: + MAXINT = sys.maxint + class MatrixFederationEndpointFactory(object): def __init__(self, hs): self.tls_client_options_factory = hs.tls_client_options_factory def endpointForURI(self, uri): - destination = uri.netloc + destination = uri.netloc.decode('ascii') return matrix_federation_endpoint( reactor, destination, timeout=10, @@ -93,26 +98,32 @@ class MatrixFederationHttpClient(object): ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string = hs.version_string + self.version_string = hs.version_string.encode('ascii') self._next_id = 1 def _create_url(self, destination, path_bytes, param_bytes, query_bytes): - return urlparse.urlunparse( - ("matrix", destination, path_bytes, param_bytes, query_bytes, "") + return urllib.parse.urlunparse( + (b"matrix", destination, path_bytes, param_bytes, query_bytes, b"") ) @defer.inlineCallbacks def _request(self, destination, method, path, - body_callback, headers_dict={}, param_bytes=b"", - query_bytes=b"", retry_on_dns_fail=True, + json=None, json_callback=None, + param_bytes=b"", + query=None, retry_on_dns_fail=True, timeout=None, long_retries=False, ignore_backoff=False, backoff_on_404=False): - """ Creates and sends a request to the given server + """ + Creates and sends a request to the given server. + Args: destination (str): The remote server to send the HTTP request to. method (str): HTTP method path (str): The HTTP path + json (dict or None): JSON to send in the body. + json_callback (func or None): A callback to generate the JSON. + query (dict or None): Query arguments. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. backoff_on_404 (bool): Back off if we get a 404 @@ -133,7 +144,7 @@ class MatrixFederationHttpClient(object): failures, connection failures, SSL failures.) """ if ( - self.hs.config.federation_domain_whitelist and + self.hs.config.federation_domain_whitelist is not None and destination not in self.hs.config.federation_domain_whitelist ): raise FederationDeniedError(destination) @@ -146,22 +157,29 @@ class MatrixFederationHttpClient(object): ignore_backoff=ignore_backoff, ) - destination = destination.encode("ascii") + headers_dict = {} path_bytes = path.encode("ascii") - with limiter: - headers_dict[b"User-Agent"] = [self.version_string] - headers_dict[b"Host"] = [destination] + if query: + query_bytes = encode_query_args(query) + else: + query_bytes = b"" - url_bytes = self._create_url( - destination, path_bytes, param_bytes, query_bytes - ) + headers_dict = { + "User-Agent": [self.version_string], + "Host": [destination], + } + + with limiter: + url = self._create_url( + destination.encode("ascii"), path_bytes, param_bytes, query_bytes + ).decode('ascii') txn_id = "%s-O-%s" % (method, self._next_id) - self._next_id = (self._next_id + 1) % (sys.maxint - 1) + self._next_id = (self._next_id + 1) % (MAXINT - 1) outbound_logger.info( "{%s} [%s] Sending request: %s %s", - txn_id, destination, method, url_bytes + txn_id, destination, method, url ) # XXX: Would be much nicer to retry only at the transaction-layer @@ -171,23 +189,33 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - http_url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "") - ) + http_url = urllib.parse.urlunparse( + (b"", b"", path_bytes, param_bytes, query_bytes, b"") + ).decode('ascii') log_result = None try: while True: - producer = None - if body_callback: - producer = body_callback(method, http_url_bytes, headers_dict) - try: - request_deferred = self.agent.request( + if json_callback: + json = json_callback() + + if json: + data = encode_canonical_json(json) + headers_dict["Content-Type"] = ["application/json"] + self.sign_request( + destination, method, http_url, headers_dict, json + ) + else: + data = None + self.sign_request(destination, method, http_url, headers_dict) + + request_deferred = treq.request( method, - url_bytes, - Headers(headers_dict), - producer + url, + headers=Headers(headers_dict), + data=data, + agent=self.agent, ) add_timeout_to_deferred( request_deferred, @@ -218,7 +246,7 @@ class MatrixFederationHttpClient(object): txn_id, destination, method, - url_bytes, + url, _flatten_response_never_received(e), ) @@ -252,7 +280,7 @@ class MatrixFederationHttpClient(object): # :'( # Update transactions table? with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + body = yield treq.content(response) raise HttpResponseException( response.code, response.phrase, body ) @@ -297,11 +325,11 @@ class MatrixFederationHttpClient(object): auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(bytes( + auth_headers.append(( "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( self.server_name, key, sig, - ) - )) + )).encode('ascii') + ) headers_dict[b"Authorization"] = auth_headers @@ -347,24 +375,14 @@ class MatrixFederationHttpClient(object): """ if not json_data_callback: - def json_data_callback(): - return data - - def body_callback(method, url_bytes, headers_dict): - json_data = json_data_callback() - self.sign_request( - destination, method, url_bytes, headers_dict, json_data - ) - producer = _JsonProducer(json_data) - return producer + json_data_callback = lambda: data response = yield self._request( destination, "PUT", path, - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, - query_bytes=encode_query_args(args), + json_callback=json_data_callback, + query=args, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, @@ -376,8 +394,8 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - defer.returnValue(json.loads(body)) + body = yield treq.json_content(response) + defer.returnValue(body) @defer.inlineCallbacks def post_json(self, destination, path, data={}, long_retries=False, @@ -410,20 +428,12 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ - - def body_callback(method, url_bytes, headers_dict): - self.sign_request( - destination, method, url_bytes, headers_dict, data - ) - return _JsonProducer(data) - response = yield self._request( destination, "POST", path, - query_bytes=encode_query_args(args), - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, + query=args, + json=data, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, @@ -434,9 +444,9 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + body = yield treq.json_content(response) - defer.returnValue(json.loads(body)) + defer.returnValue(body) @defer.inlineCallbacks def get_json(self, destination, path, args=None, retry_on_dns_fail=True, @@ -471,16 +481,11 @@ class MatrixFederationHttpClient(object): logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None - response = yield self._request( destination, "GET", path, - query_bytes=encode_query_args(args), - body_callback=body_callback, + query=args, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, ignore_backoff=ignore_backoff, @@ -491,9 +496,9 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + body = yield treq.json_content(response) - defer.returnValue(json.loads(body)) + defer.returnValue(body) @defer.inlineCallbacks def delete_json(self, destination, path, long_retries=False, @@ -523,13 +528,11 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ - response = yield self._request( destination, "DELETE", path, - query_bytes=encode_query_args(args), - headers_dict={"Content-Type": ["application/json"]}, + query=args, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, @@ -540,9 +543,9 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + body = yield treq.json_content(response) - defer.returnValue(json.loads(body)) + defer.returnValue(body) @defer.inlineCallbacks def get_file(self, destination, path, output_stream, args={}, @@ -569,26 +572,11 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ - - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, string_types): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) - logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail) - - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None - response = yield self._request( destination, "GET", path, - query_bytes=query_bytes, - body_callback=body_callback, + query=args, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -639,30 +627,6 @@ def _readBodyToFile(response, stream, max_size): return d -class _JsonProducer(object): - """ Used by the twisted http client to create the HTTP body from json - """ - def __init__(self, jsn): - self.reset(jsn) - - def reset(self, jsn): - self.body = encode_canonical_json(jsn) - self.length = len(self.body) - - def startProducing(self, consumer): - consumer.write(self.body) - return defer.succeed(None) - - def pauseProducing(self): - pass - - def stopProducing(self): - pass - - def resumeProducing(self): - pass - - def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( @@ -693,7 +657,7 @@ def check_content_type_is_json(headers): "No Content-Type header" ) - c_type = c_type[0] # only the first header + c_type = c_type[0].decode('ascii') # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": raise RuntimeError( @@ -711,6 +675,6 @@ def encode_query_args(args): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] - query_bytes = urllib.urlencode(encoded_args, True) + query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes + return query_bytes.encode('utf8') diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 588e280571..72c2654678 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import threading from prometheus_client.core import Counter, Histogram @@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter( # The set of all in flight requests, set[RequestMetrics] _in_flight_requests = set() +# Protects the _in_flight_requests set from concurrent accesss +_in_flight_requests_lock = threading.Lock() + def _get_in_flight_counts(): """Returns a count of all in flight requests by (method, server_name) @@ -120,7 +124,8 @@ def _get_in_flight_counts(): """ # Cast to a list to prevent it changing while the Prometheus # thread is collecting metrics - reqs = list(_in_flight_requests) + with _in_flight_requests_lock: + reqs = list(_in_flight_requests) for rm in reqs: rm.update_metrics() @@ -154,10 +159,12 @@ class RequestMetrics(object): # to the "in flight" metrics. self._request_stats = self.start_context.get_resource_usage() - _in_flight_requests.add(self) + with _in_flight_requests_lock: + _in_flight_requests.add(self) def stop(self, time_sec, request): - _in_flight_requests.discard(self) + with _in_flight_requests_lock: + _in_flight_requests.discard(self) context = LoggingContext.current_context() diff --git a/synapse/http/site.py b/synapse/http/site.py index 88ed3714f9..f0828c6542 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -204,14 +204,14 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.method, + self.start_time, name=servlet_name, method=self.method.decode('ascii'), ) self.site.access_logger.info( "%s - %s - Received request: %s %s", self.getClientIP(), self.site.site_tag, - self.method, + self.method.decode('ascii'), self.get_redacted_uri() ) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index ce678d5f75..167167be0a 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading + import six from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily @@ -78,6 +80,9 @@ _background_process_counts = dict() # type: dict[str, int] # of process descriptions that no longer have any active processes. _background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +# A lock that covers the above dicts +_bg_metrics_lock = threading.Lock() + class _Collector(object): """A custom metrics collector for the background process metrics. @@ -92,7 +97,11 @@ class _Collector(object): labels=["name"], ) - for desc, processes in six.iteritems(_background_processes): + # We copy the dict so that it doesn't change from underneath us + with _bg_metrics_lock: + _background_processes_copy = dict(_background_processes) + + for desc, processes in six.iteritems(_background_processes_copy): background_process_in_flight_count.add_metric( (desc,), len(processes), ) @@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs): """ @defer.inlineCallbacks def run(): - count = _background_process_counts.get(desc, 0) - _background_process_counts[desc] = count + 1 + with _bg_metrics_lock: + count = _background_process_counts.get(desc, 0) + _background_process_counts[desc] = count + 1 + _background_process_start_count.labels(desc).inc() with LoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) proc = _BackgroundProcess(desc, context) - _background_processes.setdefault(desc, set()).add(proc) + + with _bg_metrics_lock: + _background_processes.setdefault(desc, set()).add(proc) + try: yield func(*args, **kwargs) finally: proc.update_metrics() - _background_processes[desc].remove(proc) + + with _bg_metrics_lock: + _background_processes[desc].remove(proc) with PreserveLoggingContext(): return run() diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 987eec3ef2..6dd5179320 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -39,7 +39,8 @@ REQUIREMENTS = { "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], - "Twisted>=16.0.0": ["twisted>=16.0.0"], + "Twisted>=17.1.0": ["twisted>=17.1.0"], + "treq>=15.1": ["treq>=15.1"], # We use crypto.get_elliptic_curve which is only supported in >=0.15 "pyopenssl>=0.15": ["OpenSSL>=0.15"], @@ -78,6 +79,9 @@ CONDITIONAL_REQUIREMENTS = { "affinity": { "affinity": ["affinity"], }, + "postgres": { + "psycopg2>=2.6": ["psycopg2"] + } } diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fcc1091760..976d98387d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -531,7 +531,7 @@ class RoomEventServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) event = yield self.event_handler.get_event(requester.user, room_id, event_id) time_now = self.clock.time_msec() diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py index 5e99cffbcb..dadb376b02 100644 --- a/synapse/rest/client/v1_only/register.py +++ b/synapse/rest/client/v1_only/register.py @@ -23,6 +23,7 @@ from twisted.internet import defer import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError +from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.rest.client.v1.base import ClientV1RestServlet from synapse.types import create_requester @@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet): register_json["user"].encode("utf-8") if "user" in register_json else None ) + threepid = None + if session.get(LoginType.EMAIL_IDENTITY): + threepid = session["threepidCreds"] handler = self.handlers.registration_handler (user_id, token) = yield handler.register( localpart=desired_user_id, - password=password + password=password, + threepid=threepid, ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if is_threepid_reserved(self.hs.config, threepid): + yield self.store.upsert_monthly_active_user(user_id) if session[LoginType.EMAIL_IDENTITY]: logger.debug("Binding emails %s to %s" % ( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 372648cafd..37b32dd37b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( @@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( @@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2f64155d13..192f52e462 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,6 +26,7 @@ import synapse import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError +from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -74,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized to register on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -114,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Phone numbers are not authorized to register on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -372,7 +377,9 @@ class RegisterRestServlet(RestServlet): if not check_3pid_allowed(self.hs, medium, address): raise SynapseError( - 403, "Third party identifier is not allowed", + 403, + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) @@ -395,12 +402,21 @@ class RegisterRestServlet(RestServlet): if desired_username is not None: desired_username = desired_username.lower() + threepid = None + if auth_result: + threepid = auth_result.get(LoginType.EMAIL_IDENTITY) + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, generate_token=False, + threepid=threepid, ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if is_threepid_reserved(self.hs.config, threepid): + yield self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 1275baa1ba..263d8eb73e 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, + format_event_raw, serialize_event, ) from synapse.handlers.presence import format_user_presence_state @@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_response(time_now, sync_result, access_token_id, filter): + if filter.event_format == 'client': + event_formatter = format_event_for_client_v2_without_room_id + elif filter.event_format == 'federation': + event_formatter = format_event_raw + else: + raise Exception("Unknown event format %s" % (filter.event_format, )) + joined = SyncRestServlet.encode_joined( - sync_result.joined, time_now, access_token_id, filter.event_fields + sync_result.joined, time_now, access_token_id, + filter.event_fields, + event_formatter, ) invited = SyncRestServlet.encode_invited( sync_result.invited, time_now, access_token_id, + event_formatter, ) archived = SyncRestServlet.encode_archived( sync_result.archived, time_now, access_token_id, filter.event_fields, + event_formatter, ) return { @@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet): } @staticmethod - def encode_joined(rooms, time_now, token_id, event_fields): + def encode_joined(rooms, time_now, token_id, event_fields, event_formatter): """ Encode the joined rooms in a sync result @@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet): token_id(int): ID of the user's auth token - used for namespacing of transaction IDs event_fields(list<str>): List of event fields to include. If empty, - all fields will be returned. + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = SyncRestServlet.encode_room( - room, time_now, token_id, only_fields=event_fields + room, time_now, token_id, joined=True, only_fields=event_fields, + event_formatter=event_formatter, ) return joined @staticmethod - def encode_invited(rooms, time_now, token_id): + def encode_invited(rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet): time_now(int): current time - used as a baseline for age calculations token_id(int): ID of the user's auth token - used for namespacing - of transaction IDs + of transaction IDs + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: the invited rooms list, in our @@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet): for room in rooms: invite = serialize_event( room.invite, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_room_id, + event_format=event_formatter, is_invite=True, ) unsigned = dict(invite.get("unsigned", {})) @@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet): return invited @staticmethod - def encode_archived(rooms, time_now, token_id, event_fields): + def encode_archived(rooms, time_now, token_id, event_fields, event_formatter): """ Encode the archived rooms in a sync result @@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet): token_id(int): ID of the user's auth token - used for namespacing of transaction IDs event_fields(list<str>): List of event fields to include. If empty, - all fields will be returned. + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = SyncRestServlet.encode_room( - room, time_now, token_id, joined=False, only_fields=event_fields + room, time_now, token_id, joined=False, + only_fields=event_fields, + event_formatter=event_formatter, ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True, only_fields=None): + def encode_room( + room, time_now, token_id, joined, + only_fields, event_formatter, + ): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet): joined (bool): True if the user is joined to this room - will mean we handle ephemeral events only_fields(list<str>): Optional. The list of event fields to include. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, object]: the room, encoded in our response format """ def serialize(event): - # TODO(mjark): Respect formatting requirements in the filter. return serialize_event( event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_room_id, + event_format=event_formatter, only_event_fields=only_fields, ) diff --git a/synapse/server.py b/synapse/server.py index 26228d8c72..938a05f9dc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,6 +19,7 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation +import abc import logging from twisted.enterprise import adbapi @@ -56,7 +57,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler from synapse.handlers.presence import PresenceHandler -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.room import RoomContextHandler, RoomCreationHandler @@ -81,7 +82,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStore from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -111,6 +111,8 @@ class HomeServer(object): config (synapse.config.homeserver.HomeserverConfig): """ + __metaclass__ = abc.ABCMeta + DEPENDENCIES = [ 'http_client', 'db_pool', @@ -172,6 +174,11 @@ class HomeServer(object): 'room_context_handler', ] + # This is overridden in derived application classes + # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be + # instantiated during setup() for future return by get_datastore() + DATASTORE_CLASS = abc.abstractproperty() + def __init__(self, hostname, reactor=None, **kwargs): """ Args: @@ -188,13 +195,16 @@ class HomeServer(object): self.distributor = Distributor() self.ratelimiter = Ratelimiter() + self.datastore = None + # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) def setup(self): logger.info("Setting up.") - self.datastore = DataStore(self.get_db_conn(), self) + with self.get_db_conn() as conn: + self.datastore = self.DATASTORE_CLASS(conn, self) logger.info("Finished setting up.") def get_reactor(self): @@ -308,7 +318,10 @@ class HomeServer(object): return InitialSyncHandler(self) def build_profile_handler(self): - return ProfileHandler(self) + if self.config.worker_app: + return BaseProfileHandler(self) + else: + return MasterProfileHandler(self) def build_event_creation_handler(self): return EventCreationHandler(self) diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py new file mode 100644 index 0000000000..af15cba0ee --- /dev/null +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.api.constants import ( + EventTypes, + ServerNoticeLimitReached, + ServerNoticeMsgType, +) +from synapse.api.errors import AuthError, ResourceLimitError, SynapseError +from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG + +logger = logging.getLogger(__name__) + + +class ResourceLimitsServerNotices(object): + """ Keeps track of whether the server has reached it's resource limit and + ensures that the client is kept up to date. + """ + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + self._server_notices_manager = hs.get_server_notices_manager() + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._config = hs.config + self._resouce_limited = False + self._message_handler = hs.get_message_handler() + self._state = hs.get_state_handler() + + self._notifier = hs.get_notifier() + + @defer.inlineCallbacks + def maybe_send_server_notice_to_user(self, user_id): + """Check if we need to send a notice to this user, this will be true in + two cases. + 1. The server has reached its limit does not reflect this + 2. The room state indicates that the server has reached its limit when + actually the server is fine + + Args: + user_id (str): user to check + + Returns: + Deferred + """ + if self._config.hs_disabled is True: + return + + if self._config.limit_usage_by_mau is False: + return + + if not self._server_notices_manager.is_enabled(): + # Don't try and send server notices unles they've been enabled + return + + timestamp = yield self._store.user_last_seen_monthly_active(user_id) + if timestamp is None: + # This user will be blocked from receiving the notice anyway. + # In practice, not sure we can ever get here + return + + # Determine current state of room + + room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) + + if not room_id: + logger.warn("Failed to get server notices room") + return + + yield self._check_and_set_tags(user_id, room_id) + currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + + try: + # Normally should always pass in user_id if you have it, but in + # this case are checking what would happen to other users if they + # were to arrive. + try: + yield self._auth.check_auth_blocking() + is_auth_blocking = False + except ResourceLimitError as e: + is_auth_blocking = True + event_content = e.msg + event_limit_type = e.limit_type + + if currently_blocked and not is_auth_blocking: + # Room is notifying of a block, when it ought not to be. + # Remove block notification + content = { + "pinned": ref_events + } + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, '', + ) + + elif not currently_blocked and is_auth_blocking: + # Room is not notifying of a block, when it ought to be. + # Add block notification + content = { + 'body': event_content, + 'msgtype': ServerNoticeMsgType, + 'server_notice_type': ServerNoticeLimitReached, + 'admin_contact': self._config.admin_contact, + 'limit_type': event_limit_type + } + event = yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Message, + ) + + content = { + "pinned": [ + event.event_id, + ] + } + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, '', + ) + + except SynapseError as e: + logger.error("Error sending resource limits server notice: %s", e) + + @defer.inlineCallbacks + def _check_and_set_tags(self, user_id, room_id): + """ + Since server notices rooms were originally not with tags, + important to check that tags have been set correctly + Args: + user_id(str): the user in question + room_id(str): the server notices room for that user + """ + tags = yield self._store.get_tags_for_room(user_id, room_id) + need_to_set_tag = True + if tags: + if SERVER_NOTICE_ROOM_TAG in tags: + # tag already present, nothing to do here + need_to_set_tag = False + if need_to_set_tag: + max_id = yield self._store.add_tag_to_room( + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} + ) + self._notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + @defer.inlineCallbacks + def _is_room_currently_blocked(self, room_id): + """ + Determines if the room is currently blocked + + Args: + room_id(str): The room id of the server notices room + + Returns: + + bool: Is the room currently blocked + list: The list of pinned events that are unrelated to limit blocking + This list can be used as a convenience in the case where the block + is to be lifted and the remaining pinned event references need to be + preserved + """ + currently_blocked = False + pinned_state_event = None + try: + pinned_state_event = yield self._state.get_current_state( + room_id, event_type=EventTypes.Pinned + ) + except AuthError: + # The user has yet to join the server notices room + pass + + referenced_events = [] + if pinned_state_event is not None: + referenced_events = list(pinned_state_event.content.get('pinned', [])) + + events = yield self._store.get_events(referenced_events) + for event_id, event in iteritems(events): + if event.type != EventTypes.Message: + continue + if event.content.get("msgtype") == ServerNoticeMsgType: + currently_blocked = True + # remove event in case we need to disable blocking later on. + if event_id in referenced_events: + referenced_events.remove(event.event_id) + + defer.returnValue((currently_blocked, referenced_events)) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index a26deace53..c5cc6d728e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -22,6 +22,8 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) +SERVER_NOTICE_ROOM_TAG = "m.server_notice" + class ServerNoticesManager(object): def __init__(self, hs): @@ -37,6 +39,8 @@ class ServerNoticesManager(object): self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id + self._notifier = hs.get_notifier() + def is_enabled(self): """Checks if server notices are enabled on this server. @@ -46,7 +50,10 @@ class ServerNoticesManager(object): return self._config.server_notices_mxid is not None @defer.inlineCallbacks - def send_notice(self, user_id, event_content): + def send_notice( + self, user_id, event_content, + type=EventTypes.Message, state_key=None + ): """Send a notice to the given user Creates the server notices room, if none exists. @@ -54,9 +61,11 @@ class ServerNoticesManager(object): Args: user_id (str): mxid of user to send event to. event_content (dict): content of event to send + type(EventTypes): type of event + is_state_event(bool): Is the event a state event Returns: - Deferred[None] + Deferred[FrozenEvent] """ room_id = yield self.get_notice_room_for_user(user_id) @@ -65,15 +74,20 @@ class ServerNoticesManager(object): logger.info("Sending server notice to %s", user_id) - yield self._event_creation_handler.create_and_send_nonmember_event( - requester, { - "type": EventTypes.Message, - "room_id": room_id, - "sender": system_mxid, - "content": event_content, - }, - ratelimit=False, + event_dict = { + "type": type, + "room_id": room_id, + "sender": system_mxid, + "content": event_content, + } + + if state_key is not None: + event_dict['state_key'] = state_key + + res = yield self._event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, ratelimit=False, ) + defer.returnValue(res) @cachedInlineCallbacks() def get_notice_room_for_user(self, user_id): @@ -142,5 +156,12 @@ class ServerNoticesManager(object): ) room_id = info['room_id'] + max_id = yield self._store.add_tag_to_room( + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}, + ) + self._notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + logger.info("Created server notices room %s for %s", room_id, user_id) defer.returnValue(room_id) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 5d23965f34..6121b2f267 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,7 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.server_notices.consent_server_notices import ConsentServerNotices +from synapse.server_notices.resource_limits_server_notices import ( + ResourceLimitsServerNotices, +) class ServerNoticesSender(object): @@ -25,34 +30,34 @@ class ServerNoticesSender(object): Args: hs (synapse.server.HomeServer): """ - # todo: it would be nice to make this more dynamic - self._consent_server_notices = ConsentServerNotices(hs) + self._server_notices = ( + ConsentServerNotices(hs), + ResourceLimitsServerNotices(hs) + ) + @defer.inlineCallbacks def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced - - Returns: - Deferred """ - return self._consent_server_notices.maybe_send_server_notice_to_user( - user_id, - ) + for sn in self._server_notices: + yield sn.maybe_send_server_notice_to_user( + user_id, + ) + @defer.inlineCallbacks def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: user_id (str): mxid - - Returns: - Deferred """ # The synchrotrons use a stubbed version of ServerNoticesSender, so # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing - return self._consent_server_notices.maybe_send_server_notice_to_user( - user_id, - ) + for sn in self._server_notices: + yield sn.maybe_send_server_notice_to_user( + user_id, + ) diff --git a/synapse/state.py b/synapse/state/__init__.py index 0f2bedb694..d7ae22a661 100644 --- a/synapse/state.py +++ b/synapse/state/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import hashlib import logging from collections import namedtuple -from six import iteritems, iterkeys, itervalues +from six import iteritems, itervalues from frozendict import frozendict from twisted.internet import defer -from synapse import event_auth -from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError +from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext +from synapse.state import v1 from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -264,6 +262,7 @@ class StateHandler(object): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") + entry = yield self.resolve_state_groups_for_events( event.room_id, [e for e, _ in event.prev_events], ) @@ -338,8 +337,11 @@ class StateHandler(object): event, resolves conflicts between them and returns them. Args: - room_id (str): - event_ids (list[str]): + room_id (str) + event_ids (list[str]) + explicit_room_version (str|None): If set uses the the given room + version to choose the resolution algorithm. If None, then + checks the database for room version. Returns: Deferred[_StateCacheEntry]: resolved state @@ -353,7 +355,12 @@ class StateHandler(object): room_id, event_ids ) - if len(state_groups_ids) == 1: + if len(state_groups_ids) == 0: + defer.returnValue(_StateCacheEntry( + state={}, + state_group=None, + )) + elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) @@ -365,8 +372,11 @@ class StateHandler(object): delta_ids=delta_ids, )) + room_version = yield self.store.get_room_version(room_id) + result = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups_ids, None, self._state_map_factory, + room_id, room_version, state_groups_ids, None, + self._state_map_factory, ) defer.returnValue(result) @@ -375,7 +385,8 @@ class StateHandler(object): ev_ids, get_prev_content=False, check_redacted=False, ) - def resolve_events(self, state_sets, event): + @defer.inlineCallbacks + def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -391,13 +402,17 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map(state_set_ids, state_map) + new_state = yield resolve_events_with_factory( + room_version, state_set_ids, + event_map=state_map, + state_map_factory=self._state_map_factory + ) new_state = { key: state_map[ev_id] for key, ev_id in iteritems(new_state) } - return new_state + defer.returnValue(new_state) class StateResolutionHandler(object): @@ -430,7 +445,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_map_factory, ): """Resolves conflicts between a set of state groups @@ -439,6 +454,7 @@ class StateResolutionHandler(object): Args: room_id (str): room we are resolving for (used for logging) + room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group (where 'state' is a map from state key to event id) @@ -492,6 +508,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_factory( + room_version, list(itervalues(state_groups_ids)), event_map=event_map, state_map_factory=state_map_factory, @@ -575,94 +592,11 @@ def _make_state_cache_entry( ) -def _ordered_events(events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() - - return sorted(events, key=key_func) - - -def resolve_events_with_state_map(state_sets, state_map): - """ - Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - state_map(dict): a dict from event_id to event, for all events in - state_sets. - - Returns - dict[(str, str), str]: - a map from (type, state_key) to event_id. +def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): """ - if len(state_sets) == 1: - return state_sets[0] - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - return _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - ) - - -def _seperate(state_sets): - """Takes the state_sets and figures out which keys are conflicted and - which aren't. i.e., which have multiple different event_ids associated - with them in different state sets. - Args: - state_sets(iterable[dict[(str, str), str]]): - List of dicts of (type, state_key) -> event_id, which are the - different state groups to resolve. - - Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: - - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. - - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. - """ - state_set_iterator = iter(state_sets) - unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} - - for state_set in state_set_iterator: - for key, value in iteritems(state_set): - # Check if there is an unconflicted entry for the state key. - unconflicted_value = unconflicted_state.get(key) - if unconflicted_value is None: - # There isn't an unconflicted entry so check if there is a - # conflicted entry. - ls = conflicted_state.get(key) - if ls is None: - # There wasn't a conflicted entry so haven't seen this key before. - # Therefore it isn't conflicted yet. - unconflicted_state[key] = value - else: - # This key is already conflicted, add our value to the conflict set. - ls.add(value) - elif unconflicted_value != value: - # If the unconflicted value is not the same as our value then we - # have a new conflict. So move the key from the unconflicted_state - # to the conflicted state. - conflicted_state[key] = {value, unconflicted_value} - unconflicted_state.pop(key, None) - - return unconflicted_state, conflicted_state - + room_version(str): Version of the room -@defer.inlineCallbacks -def resolve_events_with_factory(state_sets, event_map, state_map_factory): - """ - Args: state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. @@ -682,185 +616,13 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if len(state_sets) == 1: - defer.returnValue(state_sets[0]) - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids - ) - if event_map is not None: - needed_events -= set(iterkeys(event_map)) - - logger.info("Asking for %d conflicted events", len(needed_events)) - - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) - if event_map is not None: - state_map.update(event_map) - - # get the ids of the auth events which allow us to authenticate the - # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - new_needed_events = set(itervalues(auth_events)) - new_needed_events -= needed_events - if event_map is not None: - new_needed_events -= set(iterkeys(event_map)) - - logger.info("Asking for %d auth events", len(new_needed_events)) - - state_map_new = yield state_map_factory(new_needed_events) - state_map.update(state_map_new) - - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) - - -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): - auth_events = {} - for event_ids in itervalues(conflicted_state): - for event_id in event_ids: - if event_id in state_map: - keys = event_auth.auth_types_for_event(state_map[event_id]) - for key in keys: - if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id - return auth_events - - -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): - conflicted_state = {} - for key, event_ids in iteritems(conflicted_state_ids): - events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] - if len(events) > 1: - conflicted_state[key] = events - elif len(events) == 1: - unconflicted_state_ids[key] = events[0].event_id - - auth_events = { - key: state_map[ev_id] - for key, ev_id in iteritems(auth_event_ids) - if ev_id in state_map - } - - try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_factory( + state_sets, event_map, state_map_factory, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) ) - except Exception: - logger.exception("Failed to resolve state") - raise - - new_state = unconflicted_state_ids - for key, event in iteritems(resolved_state): - new_state[key] = event.event_id - - return new_state - - -def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. - - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - if POWER_KEY in conflicted_state: - events = conflicted_state[POWER_KEY] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) - - return resolved_state - - -def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] - - auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) - ) - - new_auth_events = {} - for key in auth_keys: - auth_event = auth_events.get(key, None) - if auth_event: - new_auth_events[key] = auth_event - - auth_events = new_auth_events - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - -def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event diff --git a/synapse/state/v1.py b/synapse/state/v1.py new file mode 100644 index 0000000000..c95477d318 --- /dev/null +++ b/synapse/state/v1.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import logging + +from six import iteritems, iterkeys, itervalues + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError + +logger = logging.getLogger(__name__) + + +POWER_KEY = (EventTypes.PowerLevels, "") + + +@defer.inlineCallbacks +def resolve_events_with_factory(state_sets, event_map, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + + needed_events = set( + event_id + for event_ids in itervalues(conflicted_state) + for event_id in event_ids + ) + if event_map is not None: + needed_events -= set(iterkeys(event_map)) + + logger.info("Asking for %d conflicted events", len(needed_events)) + + # dict[str, FrozenEvent]: a map from state event id to event. Only includes + # the state events which are in conflict (and those in event_map) + state_map = yield state_map_factory(needed_events) + if event_map is not None: + state_map.update(event_map) + + # get the ids of the auth events which allow us to authenticate the + # conflicted state, picking only from the unconflicting state. + # + # dict[(str, str), str]: a map from state key to event id + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(itervalues(auth_events)) + new_needed_events -= needed_events + if event_map is not None: + new_needed_events -= set(iterkeys(event_map)) + + logger.info("Asking for %d auth events", len(new_needed_events)) + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + + Args: + state_sets(iterable[dict[(str, str), str]]): + List of dicts of (type, state_key) -> event_id, which are the + different state groups to resolve. + + Returns: + (dict[(str, str), str], dict[(str, str), set[str]]): + A tuple of (unconflicted_state, conflicted_state), where: + + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. + + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. + """ + state_set_iterator = iter(state_sets) + unconflicted_state = dict(next(state_set_iterator)) + conflicted_state = {} + + for state_set in state_set_iterator: + for key, value in iteritems(state_set): + # Check if there is an unconflicted entry for the state key. + unconflicted_value = unconflicted_state.get(key) + if unconflicted_value is None: + # There isn't an unconflicted entry so check if there is a + # conflicted entry. + ls = conflicted_state.get(key) + if ls is None: + # There wasn't a conflicted entry so haven't seen this key before. + # Therefore it isn't conflicted yet. + unconflicted_state[key] = value + else: + # This key is already conflicted, add our value to the conflict set. + ls.add(value) + elif unconflicted_value != value: + # If the unconflicted value is not the same as our value then we + # have a new conflict. So move the key from the unconflicted_state + # to the conflicted state. + conflicted_state[key] = {value, unconflicted_value} + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in itervalues(conflicted_state): + for event_id in event_ids: + if event_id in state_map: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, + state_map): + conflicted_state = {} + for key, event_ids in iteritems(conflicted_state_ids): + events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] + if len(events) > 1: + conflicted_state[key] = events + elif len(events) == 1: + unconflicted_state_ids[key] = events[0].event_id + + auth_events = { + key: state_map[ev_id] + for key, ev_id in iteritems(auth_event_ids) + if ev_id in state_map + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except Exception: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state_ids + for key, event in iteritems(resolved_state): + new_state[key] = event.event_id + + return new_state + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[POWER_KEY] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + prev_event = event + except AuthError: + return prev_event + + return event + + +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event + + +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() + + return sorted(events, key=key_func) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 08dffd774f..be61147b9b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,9 +17,10 @@ import sys import threading import time -from six import iteritems, iterkeys, itervalues +from six import PY2, iteritems, iterkeys, itervalues from six.moves import intern, range +from canonicaljson import json from prometheus_client import Histogram from twisted.internet import defer @@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass + + +def db_to_json(db_content): + """ + Take some data from a database row and return a JSON-decoded object. + + Args: + db_content (memoryview|buffer|bytes|bytearray|unicode) + """ + # psycopg2 on Python 3 returns memoryview objects, which we need to + # cast to bytes to decode + if isinstance(db_content, memoryview): + db_content = db_content.tobytes() + + # psycopg2 on Python 2 returns buffer objects, which we need to cast to + # bytes to decode + if PY2 and isinstance(db_content, buffer): + db_content = bytes(db_content) + + # Decode it to a Unicode string before feeding it to json.loads, so we + # consistenty get a Unicode-containing object out. + if isinstance(db_content, (bytes, bytearray)): + db_content = db_content.decode('utf8') + + try: + return json.loads(db_content) + except Exception: + logging.warning("Tried to decode '%r' as JSON and failed", db_content) + raise diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 73646da025..e06b0bc56d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore): local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} - devices = messages_by_device.keys() + devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. sql = ( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c0943ecf91..d10ff9e4b9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,7 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore +from ._base import Cache, SQLBaseStore, db_to_json logger = logging.getLogger(__name__) @@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore): if device is not None: key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore): retcol="content", desc="_get_cached_user_device", ) - defer.returnValue(json.loads(content)) + defer.returnValue(db_to_json(content)) @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): @@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore): desc="_get_cached_devices_for_user", ) defer.returnValue({ - device["device_id"]: json.loads(device["content"]) + device["device_id"]: db_to_json(device["content"]) for device in devices }) @@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 523b4360c3..1f1721e820 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,13 +14,13 @@ # limitations under the License. from six import iteritems -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class EndToEndKeyStore(SQLBaseStore): @@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - device_info["keys"] = json.loads(device_info.pop("key_json")) + device_info["keys"] = db_to_json(device_info.pop("key_json")) defer.returnValue(results) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 8a0386c1a4..42225f8a2a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -41,13 +41,18 @@ class PostgresEngine(object): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + + # Set the bytea output to escape, vs the default of hex + cursor = db_conn.cursor() + cursor.execute("SET bytea_output TO escape") + # Asynchronous commit, don't wait for the server to call fsync before # ending the transaction. # https://www.postgresql.org/docs/current/static/wal-async-commit.html if not self.synchronous_commit: - cursor = db_conn.cursor() cursor.execute("SET synchronous_commit TO OFF") - cursor.close() + + cursor.close() def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 025a7fb6d9..8bf87f38f7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict, deque, namedtuple from functools import wraps -from six import iteritems +from six import iteritems, text_type from six.moves import range from canonicaljson import json @@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore } events_map = {ev.event_id: ev for ev, _ in events_context} + room_version = yield self.get_room_version(room_id) + logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, get_events ) defer.returnValue((res.state, None)) @@ -1218,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore "sender": event.sender, "contains_url": ( "url" in event.content - and isinstance(event.content["url"], basestring) + and isinstance(event.content["url"], text_type) ), } for event, _ in events_and_contexts @@ -1527,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], basestring) + contains_url &= isinstance(content["url"], text_type) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -1908,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (room_id,) ) rows = txn.fetchall() - max_depth = max(row[0] for row in rows) + max_depth = max(row[1] for row in rows) - if max_depth <= token.topological: + if max_depth < token.topological: # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 59822178ff..a8326f5296 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import logging from collections import namedtuple @@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = zip(*event_list)[0] + event_id_lists = list(zip(*event_list))[0] event_ids = [ item for sublist in event_id_lists for item in sublist ] @@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs): + def fire(evs, exc): for _, d in evs: if not d.called: with PreserveLoggingContext(): - d.errback(e) + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list) + self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 2d5896c5b4..6ddcc909bf 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class FilteringStore(SQLBaseStore): @@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) + defer.returnValue(db_to_json(def_json)) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 06f9a75a97..c7899d7fd2 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def initialise_reserved_users(self, threepids): - # TODO Why can't I do this in init? store = self.hs.get_datastore() reserved_user_list = [] @@ -147,6 +146,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): return count return self.runInteraction("count_users", _count_users) + @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): """ Updates or inserts monthly active user member @@ -155,7 +155,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[bool]: True if a new entry was created, False if an existing one was updated. """ - is_insert = self._simple_upsert( + is_insert = yield self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", keyvalues={ @@ -200,6 +200,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): user_id(str): the user_id to query """ if self.hs.config.limit_usage_by_mau: + is_trial = yield self.is_trial_user(user_id) + if is_trial: + # we don't track trial users in the MAU table. + return + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 60295da254..88b50f33b5 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - -class ProfileStore(ProfileWorkerStore): def create_profile(self, user_localpart): return self._simple_insert( table="profiles", @@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore): desc="set_profile_avatar_url", ) + +class ProfileStore(ProfileWorkerStore): def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8443bd4c1b..c7987bfcdd 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import types + +import six from canonicaljson import encode_canonical_json, json @@ -27,6 +28,11 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) +if six.PY2: + db_binary_type = buffer +else: + db_binary_type = memoryview + class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): @@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore): dataJson = r['data'] r['data'] = None try: - if isinstance(dataJson, types.BufferType): + if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") r['data'] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], dataJson, e.message, + r['id'], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], types.BufferType): + if isinstance(r['pushkey'], db_binary_type): r['pushkey'] = str(r['pushkey']).decode("UTF8") return rows diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 07333f777d..26b429e307 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(db_conn, hs) + + self.config = hs.config + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( @@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore): retcols=[ "name", "password_hash", "is_guest", "consent_version", "consent_server_notice_sent", - "appservice_id", + "appservice_id", "creation_ts", ], allow_none=True, desc="get_user_by_id", ) + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + defer.returnValue(False) + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + defer.returnValue(is_trial) + @cached() def get_user_by_access_token(self, token): """Get a user from the given access token. diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3378fc77d1..61013b8919 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) + class RoomStore(RoomWorkerStore, SearchStore): @@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore): "get_all_new_public_rooms", get_all_new_public_rooms ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) - else: - defer.returnValue(None) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd03c4168b..4b971efdba 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, db_conn, hs): super(StateGroupWorkerStore, self).__init__(db_conn, hs) + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache") + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache") ) @defer.inlineCallbacks @@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types): + def _get_state_groups_from_groups(self, groups, types, members=None): """Returns the state groups for a given set of groups, filtering on types of state events. @@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types (Iterable[str, str|None]|None): list of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all state_keys for the `type`. If None, all types are returned. + members (bool|None): If not None, then, in addition to any filtering + implied by types, the results are also filtered to only include + member events (if True), or to exclude member events (if False) Returns: dictionary state_group -> (dict of (type, state_key) -> event id) @@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, + self._get_state_groups_from_groups_txn, chunk, types, members, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, + self, txn, groups, types=None, members=None, ): results = {group: {} for group in groups} @@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): %s """) + if members is True: + sql += " AND type = '%s'" % (EventTypes.Member,) + elif members is False: + sql += " AND type <> '%s'" % (EventTypes.Member,) + # Turns out that postgres doesn't like doing a list of OR's and # is about 1000x slower, so we just issue a query for each specific # type seperately. @@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: where_clause = "" + if members is True: + where_clause += " AND type = '%s'" % EventTypes.Member + elif members is False: + where_clause += " AND type <> '%s'" % EventTypes.Member + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: @@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types, filtered_types=None): + def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): """Checks if group is in cache. See `_get_state_for_groups` Args: + cache(DictionaryCache): the state group cache to use group(int): The state group to lookup types(list[str, str|None]): List of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all @@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = cache.get(group) type_to_key = {} - # tracks whether any of ourrequested types are missing from the cache + # tracks whether any of our requested types are missing from the cache missing_types = False for typ, state_key in types: @@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if include(k[0], k[1]) }, got_all - def _get_all_state_from_cache(self, group): + def _get_all_state_from_cache(self, cache, group): """Checks if group is in cache. See `_get_state_for_groups` Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool @@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache, if False we need to query the DB for the missing state. Args: + cache(DictionaryCache): the state group cache to use group: The state group to lookup """ - is_all, _, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = cache.get(group) return state_dict_ids, is_all @@ -685,6 +735,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[dict[int, dict[(type, state_key), EventBase]]] a dictionary mapping from state group to state dictionary. """ + if types is not None: + non_member_types = [t for t in types if t[0] != EventTypes.Member] + + if filtered_types is not None and EventTypes.Member not in filtered_types: + # we want all of the membership events + member_types = None + else: + member_types = [t for t in types if t[0] == EventTypes.Member] + + else: + non_member_types = None + member_types = None + + non_member_state = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, non_member_types, filtered_types, + ) + # XXX: we could skip this entirely if member_types is [] + member_state = yield self._get_state_for_groups_using_cache( + # we set filtered_types=None as member_state only ever contain members. + groups, self._state_group_members_cache, member_types, None, + ) + + state = non_member_state + for group in groups: + state[group].update(member_state[group]) + + defer.returnValue(state) + + @defer.inlineCallbacks + def _get_state_for_groups_using_cache( + self, groups, cache, types=None, filtered_types=None + ): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. + """ if types: types = frozenset(types) results = {} @@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, got_all = self._get_some_state_from_cache( - group, types, filtered_types + cache, group, types, filtered_types ) results[group] = state_dict_ids @@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: for group in set(groups): state_dict_ids, got_all = self._get_all_state_from_cache( - group + cache, group ) results[group] = state_dict_ids @@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): missing_groups.append(group) if missing_groups: - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + # Okay, so we have some missing_types, let's fetch them. + cache_seq_num = cache.sequence # the DictionaryCache knows if it has *all* the state, but # does not know if it has all of the keys of a particular type, @@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types_to_fetch = types group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch + missing_groups, types_to_fetch, cache == self._state_group_members_cache, ) for group, group_state_dict in iteritems(group_to_state_dict): @@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # update the cache with all the things we fetched from the # database. - self._state_group_cache.update( + cache.update( cache_seq_num, key=group, value=group_state_dict, @@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ], ) - # Prefill the state group cache with this group. + # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map # is immutable. (If the map wasn't immutable then this prefill could # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_state_ids), + value=dict(current_non_member_state_ids), ) return state_group diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 428e7fa36e..0c42bd3322 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -18,14 +18,14 @@ from collections import namedtuple import six -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview @@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore): ) if result and result["response_code"]: - return result["response_code"], json.loads(str(result["response_json"])) + return result["response_code"], db_to_json(result["response_json"]) + else: return None |