diff options
author | Paul "LeoNerd" Evans <paul@matrix.org> | 2015-12-10 16:21:00 +0000 |
---|---|---|
committer | Paul "LeoNerd" Evans <paul@matrix.org> | 2015-12-10 16:21:00 +0000 |
commit | d7ee7b589f0535c21301f38e93b0cabc0cf288d4 (patch) | |
tree | fcd7d110dc66d5e175f1030d10e0bbd5624bbf3c /synapse | |
parent | Don't complain if /make_join response lacks 'prev_state' list (SYN-517) (diff) | |
parent | Merge pull request #432 from matrix-org/pushrules_refactor (diff) | |
download | synapse-d7ee7b589f0535c21301f38e93b0cabc0cf288d4.tar.xz |
Merge branch 'develop' into paul/tiny-fixes
Diffstat (limited to 'synapse')
80 files changed, 2210 insertions, 862 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index f68a15bb85..3e7e26bf60 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.11.0-rc1" +__version__ = "0.11.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3e891a6193..b9c3e6d2c4 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -207,6 +207,13 @@ class Auth(object): user_id, room_id )) + if membership == Membership.LEAVE: + forgot = yield self.store.did_forget(user_id, room_id) + if forgot: + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) + defer.returnValue(member) @defer.inlineCallbacks @@ -587,7 +594,7 @@ class Auth(object): def _get_user_from_macaroon(self, macaroon_str): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - self._validate_macaroon(macaroon) + self.validate_macaroon(macaroon, "access", False) user_prefix = "user_id = " user = None @@ -635,13 +642,27 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) - def _validate_macaroon(self, macaroon): + def validate_macaroon(self, macaroon, type_string, verify_expiry): + """ + validate that a Macaroon is understood by and was signed by this server. + + Args: + macaroon(pymacaroons.Macaroon): The macaroon to validate + type_string(str): The kind of token this is (e.g. "access", "refresh") + verify_expiry(bool): Whether to verify whether the macaroon has expired. + This should really always be True, but no clients currently implement + token refresh, so we can't enforce expiry yet. + """ v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") - v.satisfy_exact("type = access") + v.satisfy_exact("type = " + type_string) v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(self._verify_expiry) v.satisfy_exact("guest = true") + if verify_expiry: + v.satisfy_general(self._verify_expiry) + else: + v.satisfy_general(lambda c: c.startswith("time < ")) + v.verify(macaroon, self.hs.config.macaroon_secret_key) v = pymacaroons.Verifier() @@ -652,9 +673,6 @@ class Auth(object): prefix = "time < " if not caveat.startswith(prefix): return False - # TODO(daniel): Enable expiry check when clients actually know how to - # refresh tokens. (And remember to enable the tests) - return True expiry = int(caveat[len(prefix):]) now = self.hs.get_clock().time_msec() return now < expiry @@ -842,7 +860,7 @@ class Auth(object): redact_level = self._get_named_level(auth_events, "redact", 50) - if user_level > redact_level: + if user_level >= redact_level: return False redacter_domain = EventID.from_string(event.event_id).domain diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index aaa2433cae..bc03d6c287 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -50,11 +50,11 @@ class Filtering(object): # many definitions. top_level_definitions = [ - "presence" + "presence", "account_data" ] room_level_definitions = [ - "state", "timeline", "ephemeral", "private_user_data" + "state", "timeline", "ephemeral", "account_data" ] for key in top_level_definitions: @@ -131,14 +131,22 @@ class FilterCollection(object): self.filter_json.get("room", {}).get("ephemeral", {}) ) - self.room_private_user_data = Filter( - self.filter_json.get("room", {}).get("private_user_data", {}) + self.room_account_data = Filter( + self.filter_json.get("room", {}).get("account_data", {}) ) self.presence_filter = Filter( self.filter_json.get("presence", {}) ) + self.account_data = Filter( + self.filter_json.get("account_data", {}) + ) + + self.include_leave = self.filter_json.get("room", {}).get( + "include_leave", False + ) + def timeline_limit(self): return self.room_timeline_filter.limit() @@ -151,6 +159,9 @@ class FilterCollection(object): def filter_presence(self, events): return self.presence_filter.filter(events) + def filter_account_data(self, events): + return self.account_data.filter(events) + def filter_room_state(self, events): return self.room_state_filter.filter(events) @@ -160,8 +171,8 @@ class FilterCollection(object): def filter_room_ephemeral(self, events): return self.room_ephemeral_filter.filter(events) - def filter_room_private_user_data(self, events): - return self.room_private_user_data.filter(events) + def filter_room_account_data(self, events): + return self.room_account_data.filter(events) class Filter(object): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index cd7a52ec07..0807def6ca 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -15,6 +15,8 @@ # limitations under the License. import sys +from synapse.rest import ClientRestResource + sys.dont_write_bytecode = True from synapse.python_dependencies import ( check_requirements, DEPENDENCY_LINKS, MissingRequirementError @@ -53,15 +55,13 @@ from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( - CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX, + FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, + SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, SERVER_KEY_V2_PREFIX, ) from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext -from synapse.rest.client.v1 import ClientV1RestResource -from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse import events @@ -92,11 +92,8 @@ class SynapseHomeServer(HomeServer): def build_http_client(self): return MatrixFederationHttpClient(self) - def build_resource_for_client(self): - return ClientV1RestResource(self) - - def build_resource_for_client_v2_alpha(self): - return ClientV2AlphaRestResource(self) + def build_client_resource(self): + return ClientRestResource(self) def build_resource_for_federation(self): return JsonResource(self) @@ -179,16 +176,15 @@ class SynapseHomeServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "client": + client_resource = self.get_client_resource() if res["compress"]: - client_v1 = gz_wrap(self.get_resource_for_client()) - client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha()) - else: - client_v1 = self.get_resource_for_client() - client_v2 = self.get_resource_for_client_v2_alpha() + client_resource = gz_wrap(client_resource) resources.update({ - CLIENT_PREFIX: client_v1, - CLIENT_V2_ALPHA_PREFIX: client_v2, + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, }) if name == "federation": @@ -499,13 +495,28 @@ class SynapseRequest(Request): self.start_time = int(time.time() * 1000) def finished_processing(self): + + try: + context = LoggingContext.current_context() + ru_utime, ru_stime = context.get_resource_usage() + db_txn_count = context.db_txn_count + db_txn_duration = context.db_txn_duration + except: + ru_utime, ru_stime = (0, 0) + db_txn_count, db_txn_duration = (0, 0) + self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %dms %sB %s \"%s %s %s\" \"%s\"", + " Processed request: %dms (%dms, %dms) (%dms/%d)" + " %sB %s \"%s %s %s\" \"%s\"", self.getClientIP(), self.site.site_tag, self.authenticated_entity, int(time.time() * 1000) - self.start_time, + int(ru_utime * 1000), + int(ru_stime * 1000), + int(db_txn_duration * 1000), + int(db_txn_count), self.sentLength, self.code, self.method, diff --git a/synapse/config/_base.py b/synapse/config/_base.py index c18e0bdbb8..d0c9972445 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -25,18 +25,29 @@ class ConfigError(Exception): pass -class Config(object): +# We split these messages out to allow packages to override with package +# specific instructions. +MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\ +Please opt in or out of reporting anonymized homeserver usage statistics, by +setting the `report_stats` key in your config file to either True or False. +""" + +MISSING_REPORT_STATS_SPIEL = """\ +We would really appreciate it if you could help our project out by reporting +anonymized usage statistics from your homeserver. Only very basic aggregate +data (e.g. number of users) will be reported, but it helps us to track the +growth of the Matrix community, and helps us to make Matrix a success, as well +as to convince other networks that they should peer with us. + +Thank you. +""" + +MISSING_SERVER_NAME = """\ +Missing mandatory `server_name` config option. +""" - stats_reporting_begging_spiel = ( - "We would really appreciate it if you could help our project out by" - " reporting anonymized usage statistics from your homeserver. Only very" - " basic aggregate data (e.g. number of users) will be reported, but it" - " helps us to track the growth of the Matrix community, and helps us to" - " make Matrix a success, as well as to convince other networks that they" - " should peer with us." - "\nThank you." - ) +class Config(object): @staticmethod def parse_size(value): if isinstance(value, int) or isinstance(value, long): @@ -215,7 +226,7 @@ class Config(object): if config_args.report_stats is None: config_parser.error( "Please specify either --report-stats=yes or --report-stats=no\n\n" + - cls.stats_reporting_begging_spiel + MISSING_REPORT_STATS_SPIEL ) if not config_files: config_parser.error( @@ -290,6 +301,10 @@ class Config(object): yaml_config = cls.read_config_file(config_file) specified_config.update(yaml_config) + if "server_name" not in specified_config: + sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n") + sys.exit(1) + server_name = specified_config["server_name"] _, config = obj.generate_config( config_dir_path=config_dir_path, @@ -299,11 +314,8 @@ class Config(object): config.update(specified_config) if "report_stats" not in config: sys.stderr.write( - "Please opt in or out of reporting anonymized homeserver usage " - "statistics, by setting the report_stats key in your config file " - " ( " + config_path + " ) " + - "to either True or False.\n\n" + - Config.stats_reporting_begging_spiel + "\n") + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + + MISSING_REPORT_STATS_SPIEL + "\n") sys.exit(1) if generate_keys: diff --git a/synapse/config/cas.py b/synapse/config/cas.py index a337ae6ca0..326e405841 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -27,10 +27,12 @@ class CasConfig(Config): if cas_config: self.cas_enabled = cas_config.get("enabled", True) self.cas_server_url = cas_config["server_url"] + self.cas_service_url = cas_config["service_url"] self.cas_required_attributes = cas_config.get("required_attributes", {}) else: self.cas_enabled = False self.cas_server_url = None + self.cas_service_url = None self.cas_required_attributes = {} def default_config(self, config_dir_path, server_name, **kwargs): @@ -39,6 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" + # service_url: "https://homesever.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/config/server.py b/synapse/config/server.py index 5c2d6bfeab..187edd516b 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -133,6 +133,7 @@ class ServerConfig(Config): # The domain name of the server, with optional explicit port. # This is used by remote servers to connect to this server, # e.g. matrix.org, localhost:8080, etc. + # This is also the last part of your UserID. server_name: "%(server_name)s" # When running as a daemon, the file to store the pid in diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 8b6a59866f..bc5bb5cdb1 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -381,28 +381,24 @@ class Keyring(object): def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, perspective_name, perspective_keys): - limiter = yield get_retry_limiter( - perspective_name, self.clock, self.store - ) - - with limiter: - # TODO(mark): Set the minimum_valid_until_ts to that needed by - # the events being validated or the current time if validating - # an incoming request. - query_response = yield self.client.post_json( - destination=perspective_name, - path=b"/_matrix/key/v2/query", - data={ - u"server_keys": { - server_name: { - key_id: { - u"minimum_valid_until_ts": 0 - } for key_id in key_ids - } - for server_name, key_ids in server_names_and_key_ids + # TODO(mark): Set the minimum_valid_until_ts to that needed by + # the events being validated or the current time if validating + # an incoming request. + query_response = yield self.client.post_json( + destination=perspective_name, + path=b"/_matrix/key/v2/query", + data={ + u"server_keys": { + server_name: { + key_id: { + u"minimum_valid_until_ts": 0 + } for key_id in key_ids } - }, - ) + for server_name, key_ids in server_names_and_key_ids + } + }, + long_retries=True, + ) keys = {} diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9989b76591..e634b149ba 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -100,22 +100,20 @@ def format_event_raw(d): def format_event_for_client_v1(d): - d["user_id"] = d.pop("sender", None) + d = format_event_for_client_v2(d) + + sender = d.get("sender") + if sender is not None: + d["user_id"] = sender - move_keys = ( + copy_keys = ( "age", "redacted_because", "replaces_state", "prev_content", "invite_room_state", ) - for key in move_keys: + for key in copy_keys: if key in d["unsigned"]: d[key] = d["unsigned"][key] - drop_keys = ( - "auth_events", "prev_events", "hashes", "signatures", "depth", - "unsigned", "origin", "prev_state" - ) - for key in drop_keys: - d.pop(key, None) return d @@ -129,10 +127,9 @@ def format_event_for_client_v2(d): return d -def format_event_for_client_v2_without_event_id(d): +def format_event_for_client_v2_without_room_id(d): d = format_event_for_client_v2(d) d.pop("room_id", None) - d.pop("event_id", None) return d diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 3d59e1c650..0e0cb7ebc6 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -136,6 +136,7 @@ class TransportLayerClient(object): path=PREFIX + "/send/%s/" % transaction.transaction_id, data=json_data, json_data_callback=json_data_callback, + long_retries=True, ) logger.debug( diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 127b4da4f8..6b164fd2d1 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -165,7 +165,7 @@ class BaseFederationServlet(object): if code is None: continue - server.register_path(method, pattern, self._wrap(code)) + server.register_paths(method, (pattern,), self._wrap(code)) class FederationSendServlet(BaseFederationServlet): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6519f183df..5fd20285d2 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -92,7 +92,15 @@ class BaseHandler(object): membership_event = state.get((EventTypes.Member, user_id), None) if membership_event: - membership = membership_event.membership + was_forgotten_at_event = yield self.store.was_forgotten_at( + membership_event.state_key, + membership_event.room_id, + membership_event.event_id + ) + if was_forgotten_at_event: + membership = None + else: + membership = membership_event.membership else: membership = None diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/account_data.py index 1abe45ed7b..fe773bee9b 100644 --- a/synapse/handlers/private_user_data.py +++ b/synapse/handlers/account_data.py @@ -16,22 +16,23 @@ from twisted.internet import defer -class PrivateUserDataEventSource(object): +class AccountDataEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() def get_current_key(self, direction='f'): - return self.store.get_max_private_user_data_stream_id() + return self.store.get_max_account_data_stream_id() @defer.inlineCallbacks def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_private_user_data_stream_id() - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + current_stream_id = yield self.store.get_max_account_data_stream_id() results = [] + tags = yield self.store.get_updated_tags(user_id, last_stream_id) + for room_id, room_tags in tags.items(): results.append({ "type": "m.tag", @@ -39,6 +40,24 @@ class PrivateUserDataEventSource(object): "room_id": room_id, }) + account_data, room_account_data = ( + yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) + + for account_data_type, content in account_data.items(): + results.append({ + "type": account_data_type, + "content": content, + }) + + for room_id, account_data in room_account_data.items(): + for account_data_type, content in account_data.items(): + results.append({ + "type": account_data_type, + "content": content, + "room_id": room_id, + }) + defer.returnValue((results, current_stream_id)) @defer.inlineCallbacks diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d852a18555..04fa58df65 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,34 +30,27 @@ class AdminHandler(BaseHandler): @defer.inlineCallbacks def get_whois(self, user): - res = yield self.store.get_user_ip_and_agents(user) - - d = {} - for r in res: - # Note that device_id is always None - device = d.setdefault(r["device_id"], {}) - session = device.setdefault(r["access_token"], []) - session.append({ - "ip": r["ip"], - "user_agent": r["user_agent"], - "last_seen": r["last_seen"], + connections = [] + + sessions = yield self.store.get_user_ip_and_agents(user) + for session in sessions: + connections.append({ + "ip": session["ip"], + "last_seen": session["last_seen"], + "user_agent": session["user_agent"], }) ret = { "user_id": user.to_string(), - "devices": [ - { - "device_id": k, + "devices": { + "": { "sessions": [ { - # "access_token": x, TODO (erikj) - "connections": y, + "connections": connections, } - for x, y in v.items() ] - } - for k, v in d.items() - ], + }, + }, } defer.returnValue(ret) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1b11dbdffd..e64b67cdfd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID -from synapse.api.errors import LoginError, Codes +from synapse.api.errors import AuthError, LoginError, Codes from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -46,6 +46,7 @@ class AuthHandler(BaseHandler): } self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} + self.INVALID_TOKEN_HTTP_STATUS = 401 @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -297,10 +298,11 @@ class AuthHandler(BaseHandler): defer.returnValue((user_id, access_token, refresh_token)) @defer.inlineCallbacks - def login_with_cas_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id): """ - Authenticates the user with the given user ID, - intended to have been captured from a CAS response + Gets login tuple for the user with the given user ID. + The user is assumed to have been authenticated by some other + machanism (e.g. CAS) Args: user_id (str): User ID @@ -393,6 +395,23 @@ class AuthHandler(BaseHandler): )) return m.serialize() + def generate_short_term_login_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = login") + now = self.hs.get_clock().time_msec() + expiry = now + (2 * 60 * 1000) + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def validate_short_term_login_token_and_get_user_id(self, login_token): + try: + macaroon = pymacaroons.Macaroon.deserialize(login_token) + auth_api = self.hs.get_auth() + auth_api.validate_macaroon(macaroon, "login", True) + return self._get_user_from_macaroon(macaroon) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, @@ -402,6 +421,16 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon + def _get_user_from_macaroon(self, macaroon): + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", + errcode=Codes.UNKNOWN_TOKEN + ) + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = self.hash(newpassword) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 0e4c0d4d06..fe300433e6 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -28,6 +28,18 @@ import random logger = logging.getLogger(__name__) +def started_user_eventstream(distributor, user): + return distributor.fire("started_user_eventstream", user) + + +def stopped_user_eventstream(distributor, user): + return distributor.fire("stopped_user_eventstream", user) + + +def user_joined_room(distributor, user, room_id): + return distributor.fire("user_joined_room", user, room_id) + + class EventStreamHandler(BaseHandler): def __init__(self, hs): @@ -66,7 +78,7 @@ class EventStreamHandler(BaseHandler): except: logger.exception("Failed to cancel event timer") else: - yield self.distributor.fire("started_user_eventstream", user) + yield started_user_eventstream(self.distributor, user) self._streams_per_user[user] += 1 @@ -89,7 +101,7 @@ class EventStreamHandler(BaseHandler): self._stop_timer_per_user.pop(user, None) - return self.distributor.fire("stopped_user_eventstream", user) + return stopped_user_eventstream(self.distributor, user) logger.debug("Scheduling _later: for %s", user) self._stop_timer_per_user[user] = ( @@ -120,9 +132,7 @@ class EventStreamHandler(BaseHandler): timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) if is_guest: - yield self.distributor.fire( - "user_joined_room", user=auth_user, room_id=room_id - ) + yield user_joined_room(self.distributor, auth_user, room_id) events, tokens = yield self.notifier.get_events_for( auth_user, pagin_config, timeout, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c1bce07e31..2855f2d7c3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -44,6 +44,10 @@ import logging logger = logging.getLogger(__name__) +def user_joined_room(distributor, user, room_id): + return distributor.fire("user_joined_room", user, room_id) + + class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: @@ -60,10 +64,7 @@ class FederationHandler(BaseHandler): self.hs = hs - self.distributor.observe( - "user_joined_room", - self._on_user_joined - ) + self.distributor.observe("user_joined_room", self.user_joined_room) self.waiting_for_join_list = {} @@ -176,7 +177,7 @@ class FederationHandler(BaseHandler): ) try: - _, event_stream_id, max_stream_id = yield self._handle_new_event( + context, event_stream_id, max_stream_id = yield self._handle_new_event( origin, event, state=state, @@ -233,10 +234,13 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - user = UserID.from_string(event.state_key) - yield self.distributor.fire( - "user_joined_room", user=user, room_id=event.room_id - ) + prev_state = context.current_state.get((event.type, event.state_key)) + if not prev_state or prev_state.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + user = UserID.from_string(event.state_key) + yield user_joined_room(self.distributor, user, event.room_id) @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): @@ -733,9 +737,7 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: user = UserID.from_string(event.state_key) - yield self.distributor.fire( - "user_joined_room", user=user, room_id=event.room_id - ) + yield user_joined_room(self.distributor, user, event.room_id) new_pdu = event @@ -1082,7 +1084,7 @@ class FederationHandler(BaseHandler): return self.store.get_min_depth(context) @log_function - def _on_user_joined(self, user, room_id): + def user_joined_room(self, user, room_id): waiters = self.waiting_for_join_list.get( (user.to_string(), room_id), [] diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 2a99921d5f..f1fa562fff 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -20,7 +20,6 @@ from synapse.api.errors import ( CodeMessageException ) from ._base import BaseHandler -from synapse.http.client import SimpleHttpClient from synapse.util.async import run_on_reactor from synapse.api.errors import SynapseError @@ -35,13 +34,12 @@ class IdentityHandler(BaseHandler): def __init__(self, hs): super(IdentityHandler, self).__init__(hs) + self.http_client = hs.get_simple_http_client() + @defer.inlineCallbacks def threepid_from_creds(self, creds): yield run_on_reactor() - # TODO: get this from the homeserver rather than creating a new one for - # each request - http_client = SimpleHttpClient(self.hs) # XXX: make this configurable! # trustedIdServers = ['matrix.org', 'localhost:8090'] trustedIdServers = ['matrix.org', 'vector.im'] @@ -67,7 +65,7 @@ class IdentityHandler(BaseHandler): data = {} try: - data = yield http_client.get_json( + data = yield self.http_client.get_json( "https://%s%s" % ( id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid" @@ -85,7 +83,6 @@ class IdentityHandler(BaseHandler): def bind_threepid(self, creds, mxid): yield run_on_reactor() logger.debug("binding threepid %r to %s", creds, mxid) - http_client = SimpleHttpClient(self.hs) data = None if 'id_server' in creds: @@ -103,7 +100,7 @@ class IdentityHandler(BaseHandler): raise SynapseError(400, "No client_secret in creds") try: - data = yield http_client.post_urlencoded_get_json( + data = yield self.http_client.post_urlencoded_get_json( "https://%s%s" % ( id_server, "/_matrix/identity/api/v1/3pid/bind" ), @@ -121,7 +118,6 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): yield run_on_reactor() - http_client = SimpleHttpClient(self.hs) params = { 'email': email, @@ -131,7 +127,7 @@ class IdentityHandler(BaseHandler): params.update(kwargs) try: - data = yield http_client.post_urlencoded_get_json( + data = yield self.http_client.post_urlencoded_get_json( "https://%s%s" % ( id_server, "/_matrix/identity/api/v1/validate/email/requestToken" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 14051aee99..ccdd3d8473 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -26,11 +26,17 @@ from synapse.types import UserID, RoomStreamToken, StreamToken from ._base import BaseHandler +from canonicaljson import encode_canonical_json + import logging logger = logging.getLogger(__name__) +def collect_presencelike_data(distributor, user, content): + return distributor.fire("collect_presencelike_data", user, content) + + class MessageHandler(BaseHandler): def __init__(self, hs): @@ -195,10 +201,8 @@ class MessageHandler(BaseHandler): if membership == Membership.JOIN: joinee = UserID.from_string(builder.state_key) # If event doesn't include a display name, add one. - yield self.distributor.fire( - "collect_presencelike_data", - joinee, - builder.content + yield collect_presencelike_data( + self.distributor, joinee, builder.content ) if token_id is not None: @@ -211,6 +215,16 @@ class MessageHandler(BaseHandler): builder=builder, ) + if event.is_state(): + prev_state = context.current_state.get((event.type, event.state_key)) + if prev_state and event.user_id == prev_state.user_id: + prev_content = encode_canonical_json(prev_state.content) + next_content = encode_canonical_json(event.content) + if prev_content == next_content: + # Duplicate suppression for state updates with same sender + # and content. + defer.returnValue(prev_state) + if event.type == EventTypes.Member: member_handler = self.hs.get_handlers().room_member_handler yield member_handler.change_membership(event, context, is_guest=is_guest) @@ -359,6 +373,10 @@ class MessageHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + public_room_ids = yield self.store.get_public_room_ids() limit = pagin_config.limit @@ -436,14 +454,22 @@ class MessageHandler(BaseHandler): for c in current_state.values() ] - private_user_data = [] + account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - private_user_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - d["private_user_data"] = private_user_data + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events except: logger.exception("Failed to get snapshot") @@ -456,9 +482,17 @@ class MessageHandler(BaseHandler): consumeErrors=True ).addErrback(unwrapFirstError) + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + ret = { "rooms": rooms_ret, "presence": presence, + "account_data": account_data_events, "receipts": receipt, "end": now_token.to_string(), } @@ -498,14 +532,22 @@ class MessageHandler(BaseHandler): user_id, room_id, pagin_config, membership, member_event_id, is_guest ) - private_user_data = [] + account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - private_user_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - result["private_user_data"] = private_user_data + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events defer.returnValue(result) @@ -588,23 +630,28 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): - states = {} - if not is_guest: - states = yield presence_handler.get_states( - target_users=[UserID.from_string(m.user_id) for m in room_members], - auth_user=auth_user, - as_event=True, - check_auth=False, - ) + states = yield presence_handler.get_states( + target_users=[UserID.from_string(m.user_id) for m in room_members], + auth_user=auth_user, + as_event=True, + check_auth=False, + ) defer.returnValue(states.values()) - receipts_handler = self.hs.get_handlers().receipts_handler + @defer.inlineCallbacks + def get_receipts(): + receipts_handler = self.hs.get_handlers().receipts_handler + receipts = yield receipts_handler.get_receipts_for_room( + room_id, + now_token.receipt_key + ) + defer.returnValue(receipts) presence, receipts, (messages, token) = yield defer.gatherResults( [ get_presence(), - receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key), + get_receipts(), self.store.get_recent_events_for_room( room_id, limit=limit, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index aca65096fc..63d6f30a7b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -62,6 +62,14 @@ def partitionbool(l, func): return ret.get(True, []), ret.get(False, []) +def user_presence_changed(distributor, user, statuscache): + return distributor.fire("user_presence_changed", user, statuscache) + + +def collect_presencelike_data(distributor, user, content): + return distributor.fire("collect_presencelike_data", user, content) + + class PresenceHandler(BaseHandler): STATE_LEVELS = { @@ -361,9 +369,7 @@ class PresenceHandler(BaseHandler): yield self.store.set_presence_state( target_user.localpart, state_to_store ) - yield self.distributor.fire( - "collect_presencelike_data", target_user, state - ) + yield collect_presencelike_data(self.distributor, target_user, state) if now_level > was_level: state["last_active"] = self.clock.time_msec() @@ -467,7 +473,7 @@ class PresenceHandler(BaseHandler): ) @defer.inlineCallbacks - def send_invite(self, observer_user, observed_user): + def send_presence_invite(self, observer_user, observed_user): """Request the presence of a local or remote user for a local user""" if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") @@ -878,7 +884,7 @@ class PresenceHandler(BaseHandler): room_ids=room_ids, statuscache=statuscache, ) - yield self.distributor.fire("user_presence_changed", user, statuscache) + yield user_presence_changed(self.distributor, user, statuscache) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -1116,9 +1122,7 @@ class PresenceHandler(BaseHandler): self._user_cachemap[user].get_state()["last_active"] ) - yield self.distributor.fire( - "collect_presencelike_data", user, state - ) + yield collect_presencelike_data(self.distributor, user, state) if "last_active" in state: state = dict(state) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 799faffe53..576c6f09b4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -28,6 +28,14 @@ import logging logger = logging.getLogger(__name__) +def changed_presencelike_data(distributor, user, state): + return distributor.fire("changed_presencelike_data", user, state) + + +def collect_presencelike_data(distributor, user, content): + return distributor.fire("collect_presencelike_data", user, content) + + class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -95,11 +103,9 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield self.distributor.fire( - "changed_presencelike_data", target_user, { - "displayname": new_displayname, - } - ) + yield changed_presencelike_data(self.distributor, target_user, { + "displayname": new_displayname, + }) yield self._update_join_states(target_user) @@ -144,11 +150,9 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield self.distributor.fire( - "changed_presencelike_data", target_user, { - "avatar_url": new_avatar_url, - } - ) + yield changed_presencelike_data(self.distributor, target_user, { + "avatar_url": new_avatar_url, + }) yield self._update_join_states(target_user) @@ -208,9 +212,7 @@ class ProfileHandler(BaseHandler): "membership": Membership.JOIN, } - yield self.distributor.fire( - "collect_presencelike_data", user, content - ) + yield collect_presencelike_data(self.distributor, user, content) msg_handler = self.hs.get_handlers().message_handler try: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 493a087031..a037da0f70 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -31,6 +31,10 @@ import urllib logger = logging.getLogger(__name__) +def registered_user(distributor, user): + return distributor.fire("registered_user", user) + + class RegistrationHandler(BaseHandler): def __init__(self, hs): @@ -38,6 +42,7 @@ class RegistrationHandler(BaseHandler): self.distributor = hs.get_distributor() self.distributor.declare("registered_user") + self.captch_client = CaptchaServerHttpClient(hs) @defer.inlineCallbacks def check_username(self, localpart): @@ -98,7 +103,7 @@ class RegistrationHandler(BaseHandler): password_hash=password_hash ) - yield self.distributor.fire("registered_user", user) + yield registered_user(self.distributor, user) else: # autogen a random user ID attempts = 0 @@ -117,7 +122,7 @@ class RegistrationHandler(BaseHandler): token=token, password_hash=password_hash) - self.distributor.fire("registered_user", user) + yield registered_user(self.distributor, user) except SynapseError: # if user id is taken, just generate another user_id = None @@ -167,7 +172,7 @@ class RegistrationHandler(BaseHandler): token=token, password_hash="" ) - self.distributor.fire("registered_user", user) + registered_user(self.distributor, user) defer.returnValue((user_id, token)) @defer.inlineCallbacks @@ -215,7 +220,7 @@ class RegistrationHandler(BaseHandler): token=token, password_hash=None ) - yield self.distributor.fire("registered_user", user) + yield registered_user(self.distributor, user) except Exception, e: yield self.store.add_access_token_to_user(user_id, token) # Ignore Registration errors @@ -302,10 +307,7 @@ class RegistrationHandler(BaseHandler): """ Used only by c/s api v1 """ - # TODO: get this from the homeserver rather than creating a new one for - # each request - client = CaptchaServerHttpClient(self.hs) - data = yield client.post_urlencoded_get_raw( + data = yield self.captcha_client.post_urlencoded_get_raw( "http://www.google.com:80/recaptcha/api/verify", args={ 'privatekey': private_key, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3f04752581..116a998c42 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -41,6 +41,18 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" +def collect_presencelike_data(distributor, user, content): + return distributor.fire("collect_presencelike_data", user, content) + + +def user_left_room(distributor, user, room_id): + return distributor.fire("user_left_room", user=user, room_id=room_id) + + +def user_joined_room(distributor, user, room_id): + return distributor.fire("user_joined_room", user=user, room_id=room_id) + + class RoomCreationHandler(BaseHandler): PRESETS_DICT = { @@ -438,9 +450,7 @@ class RoomMemberHandler(BaseHandler): if prev_state and prev_state.membership == Membership.JOIN: user = UserID.from_string(event.user_id) - self.distributor.fire( - "user_left_room", user=user, room_id=event.room_id - ) + user_left_room(self.distributor, user, event.room_id) defer.returnValue({"room_id": room_id}) @@ -458,9 +468,7 @@ class RoomMemberHandler(BaseHandler): raise SynapseError(404, "No known servers") # If event doesn't include a display name, add one. - yield self.distributor.fire( - "collect_presencelike_data", joinee, content - ) + yield collect_presencelike_data(self.distributor, joinee, content) content.update({"membership": Membership.JOIN}) builder = self.event_builder_factory.new({ @@ -517,10 +525,13 @@ class RoomMemberHandler(BaseHandler): do_auth=do_auth, ) - user = UserID.from_string(event.user_id) - yield self.distributor.fire( - "user_joined_room", user=user, room_id=room_id - ) + prev_state = context.current_state.get((event.type, event.state_key)) + if not prev_state or prev_state.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + user = UserID.from_string(event.user_id) + yield user_joined_room(self.distributor, user, room_id) @defer.inlineCallbacks def get_inviter(self, event): @@ -743,6 +754,9 @@ class RoomMemberHandler(BaseHandler): ) defer.returnValue((token, public_key, key_validity_url, display_name)) + def forget(self, user, room_id): + self.store.forget(user.to_string(), room_id) + class RoomListHandler(BaseHandler): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index b7545c111f..bc79564287 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -17,13 +17,14 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.constants import Membership +from synapse.api.constants import Membership, EventTypes from synapse.api.filtering import Filter from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from unpaddedbase64 import decode_base64, encode_base64 +import itertools import logging @@ -79,6 +80,9 @@ class SearchHandler(BaseHandler): # What to order results by (impacts whether pagination can be doen) order_by = room_cat.get("order_by", "rank") + # Return the current state of the rooms? + include_state = room_cat.get("include_state", False) + # Include context around each event? event_context = room_cat.get( "event_context", None @@ -96,6 +100,10 @@ class SearchHandler(BaseHandler): after_limit = int(event_context.get( "after_limit", 5 )) + + # Return the historic display name and avatar for the senders + # of the events? + include_profile = bool(event_context.get("include_profile", False)) except KeyError: raise SynapseError(400, "Invalid search query") @@ -123,6 +131,17 @@ class SearchHandler(BaseHandler): if batch_group == "room_id": room_ids.intersection_update({batch_group_key}) + if not room_ids: + defer.returnValue({ + "search_categories": { + "room_events": { + "results": [], + "count": 0, + "highlights": [], + } + } + }) + rank_map = {} # event_id -> rank of event allowed_events = [] room_groups = {} # Holds result of grouping by room, if applicable @@ -131,11 +150,18 @@ class SearchHandler(BaseHandler): # Holds the next_batch for the entire result set if one of those exists global_next_batch = None + highlights = set() + if order_by == "rank": - results = yield self.store.search_msgs( + search_result = yield self.store.search_msgs( room_ids, search_term, keys ) + if search_result["highlights"]: + highlights.update(search_result["highlights"]) + + results = search_result["results"] + results_map = {r["event"].event_id: r for r in results} rank_map.update({r["event"].event_id: r["rank"] for r in results}) @@ -163,80 +189,76 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - # In this case we specifically loop through each room as the given - # limit applies to each room, rather than a global list. - # This is not necessarilly a good idea. - for room_id in room_ids: - room_events = [] - if batch_group == "room_id" and batch_group_key == room_id: - pagination_token = batch_token - else: - pagination_token = None - i = 0 - - # We keep looping and we keep filtering until we reach the limit - # or we run out of things. - # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit() and i < 5: - i += 1 - results = yield self.store.search_room( - room_id, search_term, keys, search_filter.limit() * 2, - pagination_token=pagination_token, - ) + room_events = [] + i = 0 + + pagination_token = batch_token + + # We keep looping and we keep filtering until we reach the limit + # or we run out of things. + # But only go around 5 times since otherwise synapse will be sad. + while len(room_events) < search_filter.limit() and i < 5: + i += 1 + search_result = yield self.store.search_rooms( + room_ids, search_term, keys, search_filter.limit() * 2, + pagination_token=pagination_token, + ) - results_map = {r["event"].event_id: r for r in results} + if search_result["highlights"]: + highlights.update(search_result["highlights"]) - rank_map.update({r["event"].event_id: r["rank"] for r in results}) + results = search_result["results"] - filtered_events = search_filter.filter([ - r["event"] for r in results - ]) + results_map = {r["event"].event_id: r for r in results} - events = yield self._filter_events_for_client( - user.to_string(), filtered_events - ) + rank_map.update({r["event"].event_id: r["rank"] for r in results}) - room_events.extend(events) - room_events = room_events[:search_filter.limit()] + filtered_events = search_filter.filter([ + r["event"] for r in results + ]) - if len(results) < search_filter.limit() * 2: - pagination_token = None - break - else: - pagination_token = results[-1]["pagination_token"] - - if room_events: - res = results_map[room_events[-1].event_id] - pagination_token = res["pagination_token"] - - group = room_groups.setdefault(room_id, {}) - if pagination_token: - next_batch = encode_base64("%s\n%s\n%s" % ( - "room_id", room_id, pagination_token - )) - group["next_batch"] = next_batch - - if batch_token: - global_next_batch = next_batch - - group["results"] = [e.event_id for e in room_events] - group["order"] = max( - e.origin_server_ts/1000 for e in room_events - if hasattr(e, "origin_server_ts") - ) + events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) - allowed_events.extend(room_events) + room_events.extend(events) + room_events = room_events[:search_filter.limit()] - # Normalize the group orders - if room_groups: - if len(room_groups) > 1: - mx = max(g["order"] for g in room_groups.values()) - mn = min(g["order"] for g in room_groups.values()) + if len(results) < search_filter.limit() * 2: + pagination_token = None + break + else: + pagination_token = results[-1]["pagination_token"] - for g in room_groups.values(): - g["order"] = (g["order"] - mn) * 1.0 / (mx - mn) + for event in room_events: + group = room_groups.setdefault(event.room_id, { + "results": [], + }) + group["results"].append(event.event_id) + + if room_events and len(room_events) >= search_filter.limit(): + last_event_id = room_events[-1].event_id + pagination_token = results_map[last_event_id]["pagination_token"] + + # We want to respect the given batch group and group keys so + # that if people blindly use the top level `next_batch` token + # it returns more from the same group (if applicable) rather + # than reverting to searching all results again. + if batch_group and batch_group_key: + global_next_batch = encode_base64("%s\n%s\n%s" % ( + batch_group, batch_group_key, pagination_token + )) else: - room_groups.values()[0]["order"] = 1 + global_next_batch = encode_base64("%s\n%s\n%s" % ( + "all", "", pagination_token + )) + + for room_id, group in room_groups.items(): + group["next_batch"] = encode_base64("%s\n%s\n%s" % ( + "room_id", room_id, pagination_token + )) + + allowed_events.extend(room_events) else: # We should never get here due to the guard earlier. @@ -269,6 +291,33 @@ class SearchHandler(BaseHandler): "room_key", res["end"] ).to_string() + if include_profile: + senders = set( + ev.sender + for ev in itertools.chain( + res["events_before"], [event], res["events_after"] + ) + ) + + if res["events_after"]: + last_event_id = res["events_after"][-1].event_id + else: + last_event_id = event.event_id + + state = yield self.store.get_state_for_event( + last_event_id, + types=[(EventTypes.Member, sender) for sender in senders] + ) + + res["profile_info"] = { + s.state_key: { + "displayname": s.content.get("displayname", None), + "avatar_url": s.content.get("avatar_url", None), + } + for s in state.values() + if s.type == EventTypes.Member and s.state_key in senders + } + contexts[event.event_id] = res else: contexts = {} @@ -287,22 +336,39 @@ class SearchHandler(BaseHandler): for e in context["events_after"] ] - results = { - e.event_id: { + state_results = {} + if include_state: + rooms = set(e.room_id for e in allowed_events) + for room_id in rooms: + state = yield self.state_handler.get_current_state(room_id) + state_results[room_id] = state.values() + + state_results.values() + + # We're now about to serialize the events. We should not make any + # blocking calls after this. Otherwise the 'age' will be wrong + + results = [ + { "rank": rank_map[e.event_id], "result": serialize_event(e, time_now), "context": contexts.get(e.event_id, {}), } for e in allowed_events - } - - logger.info("Found %d results", len(results)) + ] rooms_cat_res = { "results": results, - "count": len(results) + "count": len(results), + "highlights": list(highlights), } + if state_results: + rooms_cat_res["state"] = { + room_id: [serialize_event(e, time_now) for e in state] + for room_id, state in state_results.items() + } + if room_groups and "room_id" in group_keys: rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6dc9d0fb92..24c2b2fad6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -51,7 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "timeline", # TimelineBatch "state", # dict[(str, str), FrozenEvent] "ephemeral", - "private_user_data", + "account_data", ])): __slots__ = [] @@ -63,7 +63,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ self.timeline or self.state or self.ephemeral - or self.private_user_data + or self.account_data ) @@ -71,7 +71,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ "room_id", # str "timeline", # TimelineBatch "state", # dict[(str, str), FrozenEvent] - "private_user_data", + "account_data", ])): __slots__ = [] @@ -82,7 +82,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ return bool( self.timeline or self.state - or self.private_user_data + or self.account_data ) @@ -100,6 +100,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [ "next_batch", # Token for the next sync "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. @@ -185,13 +186,19 @@ class SyncHandler(BaseHandler): pagination_config=pagination_config.get_source_config("presence"), key=None ) + + membership_list = (Membership.INVITE, Membership.JOIN) + if sync_config.filter.include_leave: + membership_list += (Membership.LEAVE, Membership.BAN) + room_list = yield self.store.get_rooms_for_user_where_membership_is( user_id=sync_config.user.to_string(), - membership_list=( - Membership.INVITE, - Membership.JOIN, - Membership.LEAVE, - Membership.BAN + membership_list=membership_list + ) + + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() ) ) @@ -211,6 +218,7 @@ class SyncHandler(BaseHandler): timeline_since_token=timeline_since_token, ephemeral_by_room=ephemeral_by_room, tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, ) joined.append(room_sync) elif event.membership == Membership.INVITE: @@ -230,11 +238,13 @@ class SyncHandler(BaseHandler): leave_token=leave_token, timeline_since_token=timeline_since_token, tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, ) archived.append(room_sync) defer.returnValue(SyncResult( presence=presence, + account_data=self.account_data_for_user(account_data), joined=joined, invited=invited, archived=archived, @@ -244,7 +254,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def full_state_sync_for_joined_room(self, room_id, sync_config, now_token, timeline_since_token, - ephemeral_by_room, tags_by_room): + ephemeral_by_room, tags_by_room, + account_data_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. @@ -261,20 +272,39 @@ class SyncHandler(BaseHandler): timeline=batch, state=current_state, ephemeral=ephemeral_by_room.get(room_id, []), - private_user_data=self.private_user_data_for_room( - room_id, tags_by_room + account_data=self.account_data_for_room( + room_id, tags_by_room, account_data_by_room ), )) - def private_user_data_for_room(self, room_id, tags_by_room): - private_user_data = [] + def account_data_for_user(self, account_data): + account_data_events = [] + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + return account_data_events + + def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): + account_data_events = [] tags = tags_by_room.get(room_id) if tags is not None: - private_user_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - return private_user_data + + account_data = account_data_by_room.get(room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + return account_data_events @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): @@ -341,7 +371,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def full_state_sync_for_archived_room(self, room_id, sync_config, leave_event_id, leave_token, - timeline_since_token, tags_by_room): + timeline_since_token, tags_by_room, + account_data_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. @@ -357,8 +388,8 @@ class SyncHandler(BaseHandler): room_id=room_id, timeline=batch, state=leave_state, - private_user_data=self.private_user_data_for_room( - room_id, tags_by_room + account_data=self.account_data_for_room( + room_id, tags_by_room, account_data_by_room ), )) @@ -412,7 +443,14 @@ class SyncHandler(BaseHandler): tags_by_room = yield self.store.get_updated_tags( sync_config.user.to_string(), - since_token.private_user_data_key, + since_token.account_data_key, + ) + + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + sync_config.user.to_string(), + since_token.account_data_key, + ) ) joined = [] @@ -468,8 +506,8 @@ class SyncHandler(BaseHandler): ), state=state, ephemeral=ephemeral_by_room.get(room_id, []), - private_user_data=self.private_user_data_for_room( - room_id, tags_by_room + account_data=self.account_data_for_room( + room_id, tags_by_room, account_data_by_room ), ) logger.debug("Result for room %s: %r", room_id, room_sync) @@ -492,14 +530,15 @@ class SyncHandler(BaseHandler): for room_id in joined_room_ids: room_sync = yield self.incremental_sync_with_gap_for_room( room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room + ephemeral_by_room, tags_by_room, account_data_by_room ) if room_sync: joined.append(room_sync) for leave_event in leave_events: room_sync = yield self.incremental_sync_for_archived_room( - sync_config, leave_event, since_token, tags_by_room + sync_config, leave_event, since_token, tags_by_room, + account_data_by_room ) archived.append(room_sync) @@ -510,6 +549,7 @@ class SyncHandler(BaseHandler): defer.returnValue(SyncResult( presence=presence, + account_data=self.account_data_for_user(account_data), joined=joined, invited=invited, archived=archived, @@ -566,7 +606,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_with_gap_for_room(self, room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room): + ephemeral_by_room, tags_by_room, + account_data_by_room): """ Get the incremental delta needed to bring the client up to date for the room. Gives the client the most recent events and the changes to state. @@ -605,8 +646,8 @@ class SyncHandler(BaseHandler): timeline=batch, state=state, ephemeral=ephemeral_by_room.get(room_id, []), - private_user_data=self.private_user_data_for_room( - room_id, tags_by_room + account_data=self.account_data_for_room( + room_id, tags_by_room, account_data_by_room ), ) @@ -616,7 +657,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_for_archived_room(self, sync_config, leave_event, - since_token, tags_by_room): + since_token, tags_by_room, + account_data_by_room): """ Get the incremental delta needed to bring the client up to date for the archived room. Returns: @@ -653,8 +695,8 @@ class SyncHandler(BaseHandler): room_id=leave_event.room_id, timeline=batch, state=state_events_delta, - private_user_data=self.private_user_data_for_room( - leave_event.room_id, tags_by_room + account_data=self.account_data_for_room( + leave_event.room_id, tags_by_room, account_data_by_room ), ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 6e53538a52..b7b7c2cce8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -56,7 +56,8 @@ incoming_responses_counter = metrics.register_counter( ) -MAX_RETRIES = 4 +MAX_LONG_RETRIES = 10 +MAX_SHORT_RETRIES = 3 class MatrixFederationEndpointFactory(object): @@ -103,7 +104,7 @@ class MatrixFederationHttpClient(object): def _create_request(self, destination, method, path_bytes, body_callback, headers_dict={}, param_bytes=b"", query_bytes=b"", retry_on_dns_fail=True, - timeout=None): + timeout=None, long_retries=False): """ Creates and sends a request to the given url """ headers_dict[b"User-Agent"] = [self.version_string] @@ -123,7 +124,10 @@ class MatrixFederationHttpClient(object): # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) - retries_left = MAX_RETRIES + if long_retries: + retries_left = MAX_LONG_RETRIES + else: + retries_left = MAX_SHORT_RETRIES http_url_bytes = urlparse.urlunparse( ("", "", path_bytes, param_bytes, query_bytes, "") @@ -184,8 +188,15 @@ class MatrixFederationHttpClient(object): ) if retries_left and not timeout: - delay = 5 ** (MAX_RETRIES + 1 - retries_left) - delay *= random.uniform(0.8, 1.4) + if long_retries: + delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) + delay = min(delay, 60) + delay *= random.uniform(0.8, 1.4) + else: + delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) + delay = min(delay, 2) + delay *= random.uniform(0.8, 1.4) + yield sleep(delay) retries_left -= 1 else: @@ -236,7 +247,8 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, data={}, json_data_callback=None): + def put_json(self, destination, path, data={}, json_data_callback=None, + long_retries=False): """ Sends the specifed json data using PUT Args: @@ -247,6 +259,8 @@ class MatrixFederationHttpClient(object): the request body. This will be encoded as JSON. json_data_callback (callable): A callable returning the dict to use as the request body. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -272,6 +286,7 @@ class MatrixFederationHttpClient(object): path.encode("ascii"), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + long_retries=long_retries, ) if 200 <= response.code < 300: @@ -287,7 +302,7 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json(self, destination, path, data={}): + def post_json(self, destination, path, data={}, long_retries=True): """ Sends the specifed json data using POST Args: @@ -296,6 +311,8 @@ class MatrixFederationHttpClient(object): path (str): The HTTP path. data (dict): A dict containing the data that will be used as the request body. This will be encoded as JSON. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -315,6 +332,7 @@ class MatrixFederationHttpClient(object): path.encode("ascii"), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + long_retries=True, ) if 200 <= response.code < 300: @@ -490,6 +508,9 @@ class _JsonProducer(object): def stopProducing(self): pass + def resumeProducing(self): + pass + def _flatten_response_never_received(e): if hasattr(e, "reasons"): diff --git a/synapse/http/server.py b/synapse/http/server.py index 50feea6f1c..c44bdfc888 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -53,6 +53,23 @@ response_timer = metrics.register_distribution( labels=["method", "servlet"] ) +response_ru_utime = metrics.register_distribution( + "response_ru_utime", labels=["method", "servlet"] +) + +response_ru_stime = metrics.register_distribution( + "response_ru_stime", labels=["method", "servlet"] +) + +response_db_txn_count = metrics.register_distribution( + "response_db_txn_count", labels=["method", "servlet"] +) + +response_db_txn_duration = metrics.register_distribution( + "response_db_txn_duration", labels=["method", "servlet"] +) + + _next_request_id = 0 @@ -120,7 +137,7 @@ class HttpServer(object): """ Interface for registering callbacks on a HTTP server """ - def register_path(self, method, path_pattern, callback): + def register_paths(self, method, path_patterns, callback): """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -129,7 +146,7 @@ class HttpServer(object): Args: method (str): The method to listen to. - path_pattern (str): The regex used to match requests. + path_patterns (list<SRE_Pattern>): The regex used to match requests. callback (function): The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. @@ -165,10 +182,11 @@ class JsonResource(HttpServer, resource.Resource): self.version_string = hs.version_string self.hs = hs - def register_path(self, method, path_pattern, callback): - self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback) - ) + def register_paths(self, method, path_patterns, callback): + for path_pattern in path_patterns: + self.path_regexs.setdefault(method, []).append( + self._PathEntry(path_pattern, callback) + ) def render(self, request): """ This gets called by twisted every time someone sends us a request. @@ -220,6 +238,21 @@ class JsonResource(HttpServer, resource.Resource): self.clock.time_msec() - start, request.method, servlet_classname ) + try: + context = LoggingContext.current_context() + ru_utime, ru_stime = context.get_resource_usage() + + response_ru_utime.inc_by(ru_utime, request.method, servlet_classname) + response_ru_stime.inc_by(ru_stime, request.method, servlet_classname) + response_db_txn_count.inc_by( + context.db_txn_count, request.method, servlet_classname + ) + response_db_txn_duration.inc_by( + context.db_txn_duration, request.method, servlet_classname + ) + except: + pass + return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 9cda17fcf8..32b6d6cd72 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError import logging - logger = logging.getLogger(__name__) @@ -102,12 +101,13 @@ class RestServlet(object): def register(self, http_server): """ Register this servlet with the given HTTP server. """ - if hasattr(self, "PATTERN"): - pattern = self.PATTERN + if hasattr(self, "PATTERNS"): + patterns = self.PATTERNS for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): if hasattr(self, "on_%s" % (method,)): method_handler = getattr(self, "on_%s" % (method,)) - http_server.register_path(method, pattern, method_handler) + http_server.register_paths(method, patterns, method_handler) + else: raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 0e0c61dec8..e7c964bcd2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -16,14 +16,12 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from synapse.types import StreamToken, UserID +from synapse.types import StreamToken import synapse.util.async -import baserules +import push_rule_evaluator as push_rule_evaluator import logging -import simplejson as json -import re import random logger = logging.getLogger(__name__) @@ -33,9 +31,6 @@ class Pusher(object): INITIAL_BACKOFF = 1000 MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - DEFAULT_ACTIONS = ['dont_notify'] - - INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") def __init__(self, _hs, profile_tag, user_name, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, @@ -63,161 +58,6 @@ class Pusher(object): self.has_unread = True @defer.inlineCallbacks - def _actions_for_event(self, ev): - """ - This should take into account notification settings that the user - has configured both globally and per-room when we have the ability - to do such things. - """ - if ev['user_id'] == self.user_name: - # let's assume you probably know about messages you sent yourself - defer.returnValue(['dont_notify']) - - rawrules = yield self.store.get_push_rules_for_user(self.user_name) - - rules = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule['conditions'] = json.loads(rawrule['conditions']) - rule['actions'] = json.loads(rawrule['actions']) - rules.append(rule) - - enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name) - - user = UserID.from_string(self.user_name) - - rules = baserules.list_with_base_rules(rules, user) - - room_id = ev['room_id'] - - # get *our* member event for display name matching - my_display_name = None - our_member_event = yield self.store.get_current_state( - room_id=room_id, - event_type='m.room.member', - state_key=self.user_name, - ) - if our_member_event: - my_display_name = our_member_event[0].content.get("displayname") - - room_members = yield self.store.get_users_in_room(room_id) - room_member_count = len(room_members) - - for r in rules: - if r['rule_id'] in enabled_map: - r['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' not in r: - r['enabled'] = True - if not r['enabled']: - continue - matches = True - - conditions = r['conditions'] - actions = r['actions'] - - for c in conditions: - matches &= self._event_fulfills_condition( - ev, c, display_name=my_display_name, - room_member_count=room_member_count - ) - logger.debug( - "Rule %s %s", - r['rule_id'], "matches" if matches else "doesn't match" - ) - # ignore rules with no actions (we have an explict 'dont_notify') - if len(actions) == 0: - logger.warn( - "Ignoring rule id %s with no actions for user %s", - r['rule_id'], self.user_name - ) - continue - if matches: - logger.info( - "%s matches for user %s, event %s", - r['rule_id'], self.user_name, ev['event_id'] - ) - defer.returnValue(actions) - - logger.info( - "No rules match for user %s, event %s", - self.user_name, ev['event_id'] - ) - defer.returnValue(Pusher.DEFAULT_ACTIONS) - - @staticmethod - def _glob_to_regexp(glob): - r = re.escape(glob) - r = re.sub(r'\\\*', r'.*?', r) - r = re.sub(r'\\\?', r'.', r) - - # handle [abc], [a-z] and [!a-z] style ranges. - r = re.sub(r'\\\[(\\\!|)(.*)\\\]', - lambda x: ('[%s%s]' % (x.group(1) and '^' or '', - re.sub(r'\\\-', '-', x.group(2)))), r) - return r - - def _event_fulfills_condition(self, ev, condition, display_name, room_member_count): - if condition['kind'] == 'event_match': - if 'pattern' not in condition: - logger.warn("event_match condition with no pattern") - return False - # XXX: optimisation: cache our pattern regexps - if condition['key'] == 'content.body': - r = r'\b%s\b' % self._glob_to_regexp(condition['pattern']) - else: - r = r'^%s$' % self._glob_to_regexp(condition['pattern']) - val = _value_for_dotted_key(condition['key'], ev) - if val is None: - return False - return re.search(r, val, flags=re.IGNORECASE) is not None - - elif condition['kind'] == 'device': - if 'profile_tag' not in condition: - return True - return condition['profile_tag'] == self.profile_tag - - elif condition['kind'] == 'contains_display_name': - # This is special because display names can be different - # between rooms and so you can't really hard code it in a rule. - # Optimisation: we should cache these names and update them from - # the event stream. - if 'content' not in ev or 'body' not in ev['content']: - return False - if not display_name: - return False - return re.search( - r"\b%s\b" % re.escape(display_name), ev['content']['body'], - flags=re.IGNORECASE - ) is not None - - elif condition['kind'] == 'room_member_count': - if 'is' not in condition: - return False - m = Pusher.INEQUALITY_EXPR.match(condition['is']) - if not m: - return False - ineq = m.group(1) - rhs = m.group(2) - if not rhs.isdigit(): - return False - rhs = int(rhs) - - if ineq == '' or ineq == '==': - return room_member_count == rhs - elif ineq == '<': - return room_member_count < rhs - elif ineq == '>': - return room_member_count > rhs - elif ineq == '>=': - return room_member_count >= rhs - elif ineq == '<=': - return room_member_count <= rhs - else: - return False - else: - return True - - @defer.inlineCallbacks def get_context_for_event(self, ev): name_aliases = yield self.store.get_room_name_and_aliases( ev['room_id'] @@ -308,8 +148,14 @@ class Pusher(object): return processed = False - actions = yield self._actions_for_event(single_event) - tweaks = _tweaks_for_actions(actions) + + rule_evaluator = yield \ + push_rule_evaluator.evaluator_for_user_name_and_profile_tag( + self.user_name, self.profile_tag, single_event['room_id'], self.store + ) + + actions = yield rule_evaluator.actions_for_event(single_event) + tweaks = rule_evaluator.tweaks_for_actions(actions) if len(actions) == 0: logger.warn("Empty actions! Using default action.") @@ -448,27 +294,6 @@ class Pusher(object): self.has_unread = False -def _value_for_dotted_key(dotted_key, event): - parts = dotted_key.split(".") - val = event - while len(parts) > 0: - if parts[0] not in val: - return None - val = val[parts[0]] - parts = parts[1:] - return val - - -def _tweaks_for_actions(actions): - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if 'set_tweak' in a and 'value' in a: - tweaks[a['set_tweak']] = a['value'] - return tweaks - - class PusherConfigException(Exception): def __init__(self, msg): super(PusherConfigException, self).__init__(msg) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 1f015a7f2e..7f76382a17 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -247,6 +247,7 @@ def make_base_append_underride_rules(user): }, { 'rule_id': 'global/underride/.m.rule.message', + 'enabled': False, 'conditions': [ { 'kind': 'event_match', diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index a02fed57b4..5160775e59 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -14,7 +14,6 @@ # limitations under the License. from synapse.push import Pusher, PusherConfigException -from synapse.http.client import SimpleHttpClient from twisted.internet import defer @@ -46,7 +45,7 @@ class HttpPusher(Pusher): "'url' required in data for HTTP pusher" ) self.url = data['url'] - self.httpCli = SimpleHttpClient(self.hs) + self.http_client = _hs.get_simple_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url['url'] @@ -107,7 +106,7 @@ class HttpPusher(Pusher): if not notification_dict: defer.returnValue([]) try: - resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) + resp = yield self.http_client.post_json_get_json(self.url, notification_dict) except: logger.warn("Failed to push %s ", self.url) defer.returnValue(False) @@ -138,7 +137,7 @@ class HttpPusher(Pusher): } } try: - resp = yield self.httpCli.post_json_get_json(self.url, d) + resp = yield self.http_client.post_json_get_json(self.url, d) except: logger.exception("Failed to push %s ", self.url) defer.returnValue(False) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py new file mode 100644 index 0000000000..92c7fd048f --- /dev/null +++ b/synapse/push/push_rule_evaluator.py @@ -0,0 +1,224 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.types import UserID + +import baserules + +import logging +import simplejson as json +import re + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def evaluator_for_user_name_and_profile_tag(user_name, profile_tag, room_id, store): + rawrules = yield store.get_push_rules_for_user(user_name) + enabled_map = yield store.get_push_rules_enabled_for_user(user_name) + our_member_event = yield store.get_current_state( + room_id=room_id, + event_type='m.room.member', + state_key=user_name, + ) + + defer.returnValue(PushRuleEvaluator( + user_name, profile_tag, rawrules, enabled_map, + room_id, our_member_event, store + )) + + +class PushRuleEvaluator: + DEFAULT_ACTIONS = ['dont_notify'] + INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") + + def __init__(self, user_name, profile_tag, raw_rules, enabled_map, room_id, + our_member_event, store): + self.user_name = user_name + self.profile_tag = profile_tag + self.room_id = room_id + self.our_member_event = our_member_event + self.store = store + + rules = [] + for raw_rule in raw_rules: + rule = dict(raw_rule) + rule['conditions'] = json.loads(raw_rule['conditions']) + rule['actions'] = json.loads(raw_rule['actions']) + rules.append(rule) + + user = UserID.from_string(self.user_name) + self.rules = baserules.list_with_base_rules(rules, user) + + self.enabled_map = enabled_map + + @staticmethod + def tweaks_for_actions(actions): + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if 'set_tweak' in a and 'value' in a: + tweaks[a['set_tweak']] = a['value'] + return tweaks + + @defer.inlineCallbacks + def actions_for_event(self, ev): + """ + This should take into account notification settings that the user + has configured both globally and per-room when we have the ability + to do such things. + """ + if ev['user_id'] == self.user_name: + # let's assume you probably know about messages you sent yourself + defer.returnValue(['dont_notify']) + + room_id = ev['room_id'] + + # get *our* member event for display name matching + my_display_name = None + + if self.our_member_event: + my_display_name = self.our_member_event[0].content.get("displayname") + + room_members = yield self.store.get_users_in_room(room_id) + room_member_count = len(room_members) + + for r in self.rules: + if r['rule_id'] in self.enabled_map: + r['enabled'] = self.enabled_map[r['rule_id']] + elif 'enabled' not in r: + r['enabled'] = True + if not r['enabled']: + continue + matches = True + + conditions = r['conditions'] + actions = r['actions'] + + for c in conditions: + matches &= self._event_fulfills_condition( + ev, c, display_name=my_display_name, + room_member_count=room_member_count + ) + logger.debug( + "Rule %s %s", + r['rule_id'], "matches" if matches else "doesn't match" + ) + # ignore rules with no actions (we have an explict 'dont_notify') + if len(actions) == 0: + logger.warn( + "Ignoring rule id %s with no actions for user %s", + r['rule_id'], self.user_name + ) + continue + if matches: + logger.info( + "%s matches for user %s, event %s", + r['rule_id'], self.user_name, ev['event_id'] + ) + defer.returnValue(actions) + + logger.info( + "No rules match for user %s, event %s", + self.user_name, ev['event_id'] + ) + defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS) + + @staticmethod + def _glob_to_regexp(glob): + r = re.escape(glob) + r = re.sub(r'\\\*', r'.*?', r) + r = re.sub(r'\\\?', r'.', r) + + # handle [abc], [a-z] and [!a-z] style ranges. + r = re.sub(r'\\\[(\\\!|)(.*)\\\]', + lambda x: ('[%s%s]' % (x.group(1) and '^' or '', + re.sub(r'\\\-', '-', x.group(2)))), r) + return r + + def _event_fulfills_condition(self, ev, condition, display_name, room_member_count): + if condition['kind'] == 'event_match': + if 'pattern' not in condition: + logger.warn("event_match condition with no pattern") + return False + # XXX: optimisation: cache our pattern regexps + if condition['key'] == 'content.body': + r = r'\b%s\b' % self._glob_to_regexp(condition['pattern']) + else: + r = r'^%s$' % self._glob_to_regexp(condition['pattern']) + val = _value_for_dotted_key(condition['key'], ev) + if val is None: + return False + return re.search(r, val, flags=re.IGNORECASE) is not None + + elif condition['kind'] == 'device': + if 'profile_tag' not in condition: + return True + return condition['profile_tag'] == self.profile_tag + + elif condition['kind'] == 'contains_display_name': + # This is special because display names can be different + # between rooms and so you can't really hard code it in a rule. + # Optimisation: we should cache these names and update them from + # the event stream. + if 'content' not in ev or 'body' not in ev['content']: + return False + if not display_name: + return False + return re.search( + r"\b%s\b" % re.escape(display_name), ev['content']['body'], + flags=re.IGNORECASE + ) is not None + + elif condition['kind'] == 'room_member_count': + if 'is' not in condition: + return False + m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is']) + if not m: + return False + ineq = m.group(1) + rhs = m.group(2) + if not rhs.isdigit(): + return False + rhs = int(rhs) + + if ineq == '' or ineq == '==': + return room_member_count == rhs + elif ineq == '<': + return room_member_count < rhs + elif ineq == '>': + return room_member_count > rhs + elif ineq == '>=': + return room_member_count >= rhs + elif ineq == '<=': + return room_member_count <= rhs + else: + return False + else: + return True + + +def _value_for_dotted_key(dotted_key, event): + parts = dotted_key.split(".") + val = event + while len(parts) > 0: + if parts[0] not in val: + return None + val = val[parts[0]] + parts = parts[1:] + return val diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 1a84d94cd9..7b67e96204 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd +# Copyright 2014, 2015 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from synapse.rest.client.v1 import ( + room, + events, + profile, + presence, + initial_sync, + directory, + voip, + admin, + pusher, + push_rule, + register as v1_register, + login as v1_login, +) + +from synapse.rest.client.v2_alpha import ( + sync, + filter, + account, + register, + auth, + receipts, + keys, + tokenrefresh, + tags, + account_data, +) + +from synapse.http.server import JsonResource + + +class ClientRestResource(JsonResource): + """A resource for version 1 of the matrix client API.""" + + def __init__(self, hs): + JsonResource.__init__(self, hs, canonical_json=False) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(client_resource, hs): + # "v1" + room.register_servlets(hs, client_resource) + events.register_servlets(hs, client_resource) + v1_register.register_servlets(hs, client_resource) + v1_login.register_servlets(hs, client_resource) + profile.register_servlets(hs, client_resource) + presence.register_servlets(hs, client_resource) + initial_sync.register_servlets(hs, client_resource) + directory.register_servlets(hs, client_resource) + voip.register_servlets(hs, client_resource) + admin.register_servlets(hs, client_resource) + pusher.register_servlets(hs, client_resource) + push_rule.register_servlets(hs, client_resource) + + # "v2" + sync.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) + account.register_servlets(hs, client_resource) + register.register_servlets(hs, client_resource) + auth.register_servlets(hs, client_resource) + receipts.register_servlets(hs, client_resource) + keys.register_servlets(hs, client_resource) + tokenrefresh.register_servlets(hs, client_resource) + tags.register_servlets(hs, client_resource) + account_data.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py index cc9b49d539..c488b10d3c 100644 --- a/synapse/rest/client/v1/__init__.py +++ b/synapse/rest/client/v1/__init__.py @@ -12,33 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from . import ( - room, events, register, login, profile, presence, initial_sync, directory, - voip, admin, pusher, push_rule -) - -from synapse.http.server import JsonResource - - -class ClientV1RestResource(JsonResource): - """A resource for version 1 of the matrix client API.""" - - def __init__(self, hs): - JsonResource.__init__(self, hs, canonical_json=False) - self.register_servlets(self, hs) - - @staticmethod - def register_servlets(client_resource, hs): - room.register_servlets(hs, client_resource) - events.register_servlets(hs, client_resource) - register.register_servlets(hs, client_resource) - login.register_servlets(hs, client_resource) - profile.register_servlets(hs, client_resource) - presence.register_servlets(hs, client_resource) - initial_sync.register_servlets(hs, client_resource) - directory.register_servlets(hs, client_resource) - voip.register_servlets(hs, client_resource) - admin.register_servlets(hs, client_resource) - pusher.register_servlets(hs, client_resource) - push_rule.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index bdde43864c..886199a6da 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)") + PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 504a5e432f..6273ce0795 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def client_path_pattern(path_regex): +def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -37,7 +37,14 @@ def client_path_pattern(path_regex): Returns: SRE_Pattern """ - return re.compile("^" + CLIENT_PREFIX + path_regex) + patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] + if include_in_unstable: + unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release) + patterns.append(re.compile("^" + new_prefix + path_regex)) + return patterns class ClientV1RestServlet(RestServlet): diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 240eedac75..f488e2dd41 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError, Codes from synapse.types import RoomAlias -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import logging @@ -32,7 +32,7 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): - PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$") + PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") @defer.inlineCallbacks def on_GET(self, request, room_alias): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 3e1750d1a1..41b97e7d15 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.events.utils import serialize_event import logging @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class EventStreamRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/events$") + PATTERNS = client_path_patterns("/events$") DEFAULT_LONGPOLL_TIME_MS = 30000 @@ -72,7 +72,7 @@ class EventStreamRestServlet(ClientV1RestServlet): # TODO: Unit test gets, with and without auth, with different kinds of events. class EventRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") + PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$") def __init__(self, hs): super(EventRestServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 856a70f297..9ad3df8a9f 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,12 +16,12 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns # TODO: Needs unit testing class InitialSyncRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/initialSync$") + PATTERNS = client_path_patterns("/initialSync$") @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4ea06c1434..776e1667c1 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -16,12 +16,12 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes -from synapse.http.client import SimpleHttpClient from synapse.types import UserID -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib +import urlparse import logging from saml2 import BINDING_HTTP_POST @@ -35,10 +35,11 @@ logger = logging.getLogger(__name__) class LoginRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login$") + PATTERNS = client_path_patterns("/login$", releases=(), include_in_unstable=False) PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" + TOKEN_TYPE = "m.login.token" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) @@ -49,6 +50,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name + self.http_client = hs.get_simple_http_client() def on_GET(self, request): flows = [] @@ -56,8 +58,18 @@ class LoginRestServlet(ClientV1RestServlet): flows.append({"type": LoginRestServlet.SAML2_TYPE}) if self.cas_enabled: flows.append({"type": LoginRestServlet.CAS_TYPE}) + + # While its valid for us to advertise this login type generally, + # synapse currently only gives out these tokens as part of the + # CAS login flow. + # Generally we don't want to advertise login flows that clients + # don't know how to implement, since they (currently) will always + # fall back to the fallback API if they don't understand one of the + # login flow types returned. + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) if self.password_enabled: flows.append({"type": LoginRestServlet.PASS_TYPE}) + return (200, {"flows": flows}) def on_OPTIONS(self, request): @@ -83,19 +95,20 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): - # TODO: get this from the homeserver rather than creating a new one for - # each request - http_client = SimpleHttpClient(self.hs) uri = "%s/proxyValidate" % (self.cas_server_url,) args = { "ticket": login_submission["ticket"], "service": login_submission["service"] } - body = yield http_client.get_raw(uri, args) + body = yield self.http_client.get_raw(uri, args) result = yield self.do_cas_login(body) defer.returnValue(result) + elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + result = yield self.do_token_login(login_submission) + defer.returnValue(result) else: raise SynapseError(400, "Bad login type.") except KeyError: @@ -132,6 +145,26 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) @defer.inlineCallbacks + def do_token_login(self, login_submission): + token = login_submission['token'] + auth_handler = self.handlers.auth_handler + user_id = ( + yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + ) + user_id, access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + # TODO Delete this after all CAS clients switch to token login instead + @defer.inlineCallbacks def do_cas_login(self, cas_response_body): user, attributes = self.parse_cas_response(cas_response_body) @@ -152,7 +185,7 @@ class LoginRestServlet(ClientV1RestServlet): user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( - yield auth_handler.login_with_cas_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id) ) result = { "user_id": user_id, # may have changed @@ -173,6 +206,7 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): @@ -201,7 +235,7 @@ class LoginRestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/saml2") + PATTERNS = client_path_patterns("/login/saml2", releases=()) def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) @@ -243,8 +277,9 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) +# TODO Delete this after all CAS clients switch to token login instead class CasRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas") + PATTERNS = client_path_patterns("/login/cas", releases=()) def __init__(self, hs): super(CasRestServlet, self).__init__(hs) @@ -254,6 +289,115 @@ class CasRestServlet(ClientV1RestServlet): return (200, {"serverUrl": self.cas_server_url}) +class CasRedirectServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) + + def __init__(self, hs): + super(CasRedirectServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url + + def on_GET(self, request): + args = request.args + if "redirectUrl" not in args: + return (400, "Redirect URL not specified for CAS auth") + client_redirect_url_param = urllib.urlencode({ + "redirectUrl": args["redirectUrl"][0] + }) + hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" + service_param = urllib.urlencode({ + "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) + }) + request.redirect("%s?%s" % (self.cas_server_url, service_param)) + request.finish() + + +class CasTicketServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/login/cas/ticket", releases=()) + + def __init__(self, hs): + super(CasTicketServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url + self.cas_required_attributes = hs.config.cas_required_attributes + + @defer.inlineCallbacks + def on_GET(self, request): + client_redirect_url = request.args["redirectUrl"][0] + http_client = self.hs.get_simple_http_client() + uri = self.cas_server_url + "/proxyValidate" + args = { + "ticket": request.args["ticket"], + "service": self.cas_service_url + } + body = yield http_client.get_raw(uri, args) + result = yield self.handle_cas_response(request, body, client_redirect_url) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_cas_response(self, request, cas_response_body, client_redirect_url): + user, attributes = self.parse_cas_response(cas_response_body) + + for required_attribute, required_value in self.cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if not user_exists: + user_id, _ = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + + login_token = auth_handler.generate_short_term_login_token(user_id) + redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, + login_token) + request.redirect(redirect_url) + request.finish() + + def add_login_token_to_redirect_url(self, url, token): + url_parts = list(urlparse.urlparse(url)) + query = dict(urlparse.parse_qsl(url_parts[4])) + query.update({"loginToken": token}) + url_parts[4] = urllib.urlencode(query) + return urlparse.urlunparse(url_parts) + + def parse_cas_response(self, cas_response_body): + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not root[0].tag.endswith("authenticationSuccess"): + raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in attribute tags + # to the full URL of the namespace. + # See (https://docs.python.org/2/library/xml.etree.elementtree.html) + # We don't care about namespace here and it will always be encased in + # curly braces, so we remove them. + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + + return (user, attributes) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -269,5 +413,7 @@ def register_servlets(hs, http_server): if hs.config.saml2_enabled: SAML2RestServlet(hs).register(http_server) if hs.config.cas_enabled: + CasRedirectServlet(hs).register(http_server) + CasTicketServlet(hs).register(http_server) CasRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 6fe5d19a22..e0949fe4bb 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import logging @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") + PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -73,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)") + PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -120,7 +120,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue invited_user = UserID.from_string(u) - yield self.handlers.presence_handler.send_invite( + yield self.handlers.presence_handler.send_presence_invite( observer_user=user, observed_user=invited_user ) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 3218e47025..e6c6e5d024 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,14 +16,14 @@ """ This module contains REST servlets to do with profile: /profile/<paths> """ from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.types import UserID import simplejson as json class ProfileDisplaynameRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") + PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -56,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") + PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -89,7 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)") + PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index b0870db1ac..9270bdd079 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import ( SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError ) -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) @@ -31,7 +31,7 @@ import simplejson as json class PushRuleRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/pushrules/.*$") + PATTERNS = client_path_patterns("/pushrules/.*$") SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") @@ -207,7 +207,12 @@ class PushRuleRestServlet(ClientV1RestServlet): def set_rule_attr(self, user_name, spec, val): if spec['attr'] == 'enabled': + if isinstance(val, dict) and "enabled" in val: + val = val["enabled"] if not isinstance(val, bool): + # Legacy fallback + # This should *actually* take a dict, but many clients pass + # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) self.hs.get_datastore().set_push_rule_enabled( diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index a110c0a4f0..d6d1ad528e 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,13 +17,16 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json +import logging + +logger = logging.getLogger(__name__) class PusherRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/pushers/set$") + PATTERNS = client_path_patterns("/pushers/set$") @defer.inlineCallbacks def on_POST(self, request): @@ -51,6 +54,9 @@ class PusherRestServlet(ClientV1RestServlet): raise SynapseError(400, "Missing parameters: "+','.join(missing), errcode=Codes.MISSING_PARAM) + logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) + logger.debug("Got pushers request with body: %r", content) + append = False if 'append' in content: append = content['append'] diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index a56834e365..4b02311e05 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor @@ -48,7 +48,7 @@ class RegisterRestServlet(ClientV1RestServlet): handler doesn't have a concept of multi-stages or sessions. """ - PATTERN = client_path_pattern("/register$") + PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): super(RegisterRestServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 139dac1cc3..53cc29becb 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,7 +16,7 @@ """ This module contains REST servlets to do with rooms: /rooms/<paths> """ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership @@ -34,16 +34,16 @@ class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here def register(self, http_server): - PATTERN = "/createRoom" - register_txn_path(self, PATTERN, http_server) + PATTERNS = "/createRoom" + register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_path("OPTIONS", - client_path_pattern("/rooms(?:/.*)?$"), - self.on_OPTIONS) + http_server.register_paths("OPTIONS", + client_path_patterns("/rooms(?:/.*)?$"), + self.on_OPTIONS) # define CORS for /createRoom[/txnid] - http_server.register_path("OPTIONS", - client_path_pattern("/createRoom(?:/.*)?$"), - self.on_OPTIONS) + http_server.register_paths("OPTIONS", + client_path_patterns("/createRoom(?:/.*)?$"), + self.on_OPTIONS) @defer.inlineCallbacks def on_PUT(self, request, txn_id): @@ -103,18 +103,18 @@ class RoomStateEventRestServlet(ClientV1RestServlet): state_key = ("/rooms/(?P<room_id>[^/]*)/state/" "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") - http_server.register_path("GET", - client_path_pattern(state_key), - self.on_GET) - http_server.register_path("PUT", - client_path_pattern(state_key), - self.on_PUT) - http_server.register_path("GET", - client_path_pattern(no_state_key), - self.on_GET_no_state_key) - http_server.register_path("PUT", - client_path_pattern(no_state_key), - self.on_PUT_no_state_key) + http_server.register_paths("GET", + client_path_patterns(state_key), + self.on_GET) + http_server.register_paths("PUT", + client_path_patterns(state_key), + self.on_PUT) + http_server.register_paths("GET", + client_path_patterns(no_state_key), + self.on_GET_no_state_key) + http_server.register_paths("PUT", + client_path_patterns(no_state_key), + self.on_PUT_no_state_key) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -170,8 +170,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERN = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") - register_txn_path(self, PATTERN, http_server, with_get=True) + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") + register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): @@ -215,8 +215,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERN = ("/join/(?P<room_identifier>[^/]*)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/join/(?P<room_identifier>[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): @@ -280,7 +280,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): # TODO: Needs unit testing class PublicRoomListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/publicRooms$") + PATTERNS = client_path_patterns("/publicRooms$") @defer.inlineCallbacks def on_GET(self, request): @@ -291,7 +291,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMemberListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -328,7 +328,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): # TODO: Needs better unit testing class RoomMessageListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -351,7 +351,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomStateRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -368,7 +368,7 @@ class RoomStateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomInitialSyncRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -383,32 +383,8 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): defer.returnValue((200, content)) -class RoomTriggerBackfill(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$") - - def __init__(self, hs): - super(RoomTriggerBackfill, self).__init__(hs) - self.clock = hs.get_clock() - - @defer.inlineCallbacks - def on_GET(self, request, room_id): - remote_server = urllib.unquote( - request.args["remote"][0] - ).decode("UTF-8") - - limit = int(request.args["limit"][0]) - - handler = self.handlers.federation_handler - events = yield handler.backfill(remote_server, room_id, limit) - - time_now = self.clock.time_msec() - - res = [serialize_event(event, time_now) for event in events] - defer.returnValue((200, res)) - - class RoomEventContext(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" ) @@ -447,9 +423,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERN = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|kick)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" + "(?P<membership_action>join|invite|leave|ban|kick|forget)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): @@ -458,6 +434,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): allow_guest=True ) + effective_membership_action = membership_action + if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}: raise AuthError(403, "Guest access not allowed") @@ -488,11 +466,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet): UserID.from_string(state_key) if membership_action == "kick": - membership_action = "leave" + effective_membership_action = "leave" + elif membership_action == "forget": + effective_membership_action = "leave" msg_handler = self.handlers.message_handler - content = {"membership": unicode(membership_action)} + content = {"membership": unicode(effective_membership_action)} if is_guest: content["kind"] = "guest" @@ -509,6 +489,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): is_guest=is_guest, ) + if membership_action == "forget": + self.handlers.room_member_handler.forget(user, room_id) + defer.returnValue((200, {})) def _has_3pid_invite_keys(self, content): @@ -536,8 +519,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): def register(self, http_server): - PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): @@ -575,7 +558,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" ) @@ -608,7 +591,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): class SearchRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/search$" ) @@ -648,20 +631,20 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): http_server : The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ - http_server.register_path( + http_server.register_paths( "POST", - client_path_pattern(regex_string + "$"), + client_path_patterns(regex_string + "$"), servlet.on_POST ) - http_server.register_path( + http_server.register_paths( "PUT", - client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), + client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), servlet.on_PUT ) if with_get: - http_server.register_path( + http_server.register_paths( "GET", - client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), + client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), servlet.on_GET ) @@ -672,7 +655,6 @@ def register_servlets(hs, http_server): RoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - RoomTriggerBackfill(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index eb7c57cade..1567a03c89 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import hmac @@ -24,7 +24,7 @@ import base64 class VoipRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/voip/turnServer$") + PATTERNS = client_path_patterns("/voip/turnServer$") @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index a108132346..c488b10d3c 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -12,37 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from . import ( - sync, - filter, - account, - register, - auth, - receipts, - keys, - tokenrefresh, - tags, -) - -from synapse.http.server import JsonResource - - -class ClientV2AlphaRestResource(JsonResource): - """A resource for version 2 alpha of the matrix client API.""" - - def __init__(self, hs): - JsonResource.__init__(self, hs, canonical_json=False) - self.register_servlets(self, hs) - - @staticmethod - def register_servlets(client_resource, hs): - sync.register_servlets(hs, client_resource) - filter.register_servlets(hs, client_resource) - account.register_servlets(hs, client_resource) - register.register_servlets(hs, client_resource) - auth.register_servlets(hs, client_resource) - receipts.register_servlets(hs, client_resource) - keys.register_servlets(hs, client_resource) - tokenrefresh.register_servlets(hs, client_resource) - tags.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 4540e8dcf7..7b8b879c03 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -27,7 +27,7 @@ import simplejson logger = logging.getLogger(__name__) -def client_v2_pattern(path_regex): +def client_v2_patterns(path_regex, releases=(0,)): """Creates a regex compiled client path with the correct client path prefix. @@ -37,7 +37,13 @@ def client_v2_pattern(path_regex): Returns: SRE_Pattern """ - return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) + patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) + patterns.append(re.compile("^" + new_prefix + path_regex)) + return patterns def parse_request_allow_empty(request): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1970ad3458..3e1459d5b9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,7 +20,7 @@ from synapse.api.errors import LoginError, SynapseError, Codes from synapse.http.servlet import RestServlet from synapse.util.async import run_on_reactor -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request import logging @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class PasswordRestServlet(RestServlet): - PATTERN = client_v2_pattern("/account/password") + PATTERNS = client_v2_patterns("/account/password") def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -89,7 +89,7 @@ class PasswordRestServlet(RestServlet): class ThreepidRestServlet(RestServlet): - PATTERN = client_v2_pattern("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py new file mode 100644 index 0000000000..5b8f454bf1 --- /dev/null +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import client_v2_patterns + +from synapse.http.servlet import RestServlet +from synapse.api.errors import AuthError, SynapseError + +from twisted.internet import defer + +import logging + +import simplejson as json + +logger = logging.getLogger(__name__) + + +class AccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + """ + PATTERNS = client_v2_patterns( + "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" + ) + + def __init__(self, hs): + super(AccountDataServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, account_data_type): + auth_user, _, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid JSON") + + max_id = yield self.store.add_account_data_for_user( + user_id, account_data_type, body + ) + + yield self.notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +class RoomAccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + """ + PATTERNS = client_v2_patterns( + "/user/(?P<user_id>[^/]*)" + "/rooms/(?P<room_id>[^/]*)" + "/account_data/(?P<account_data_type>[^/]*)" + ) + + def __init__(self, hs): + super(RoomAccountDataServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, room_id, account_data_type): + auth_user, _, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid JSON") + + if not isinstance(body, dict): + raise ValueError("Expected a JSON object") + + max_id = yield self.store.add_account_data_to_room( + user_id, room_id, account_data_type, body + ) + + yield self.notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + AccountDataServlet(hs).register(http_server) + RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 4c726f05f5..fb5947a141 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern +from ._base import client_v2_patterns import logging @@ -97,7 +97,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web") + PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 97956a4b91..3cd0364b56 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, SynapseError from synapse.http.servlet import RestServlet from synapse.types import UserID -from ._base import client_v2_pattern +from ._base import client_v2_patterns import simplejson as json import logging @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") + PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") def __init__(self, hs): super(GetFilterRestServlet, self).__init__() @@ -65,7 +65,7 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") + PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter") def __init__(self, hs): super(CreateFilterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 820d33336f..753f2988a1 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -21,7 +21,7 @@ from synapse.types import UserID from canonicaljson import encode_canonical_json -from ._base import client_v2_pattern +from ._base import client_v2_patterns import simplejson as json import logging @@ -54,7 +54,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)") + PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=()) def __init__(self, hs): super(KeyUploadServlet, self).__init__() @@ -154,12 +154,13 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/keys/query(?:" "/(?P<user_id>[^/]*)(?:" "/(?P<device_id>[^/]*)" ")?" - ")?" + ")?", + releases=() ) def __init__(self, hs): @@ -245,10 +246,11 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/keys/claim(?:/?|(?:/" "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)" - ")?)" + ")?)", + releases=() ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 788acd4adb..aa214e13b6 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern +from ._base import client_v2_patterns import logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ReceiptRestServlet(RestServlet): - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/rooms/(?P<room_id>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)" "/(?P<event_id>[^/]*)$" diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f899376311..b2b89652c6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -19,7 +19,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request import logging import hmac @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) class RegisterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/register") + PATTERNS = client_v2_patterns("/register") def __init__(self, hs): super(RegisterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index efd8281558..f0a637a6da 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -22,14 +22,17 @@ from synapse.handlers.sync import SyncConfig from synapse.types import StreamToken from synapse.events import FrozenEvent from synapse.events.utils import ( - serialize_event, format_event_for_client_v2_without_event_id, + serialize_event, format_event_for_client_v2_without_room_id, ) from synapse.api.filtering import FilterCollection -from ._base import client_v2_pattern +from synapse.api.errors import SynapseError +from ._base import client_v2_patterns import copy import logging +import ujson as json + logger = logging.getLogger(__name__) @@ -48,7 +51,7 @@ class SyncRestServlet(RestServlet): "next_batch": // batch token for the next /sync "presence": // presence data for the user. "rooms": { - "joined": { // Joined rooms being updated. + "join": { // Joined rooms being updated. "${room_id}": { // Id of the room being updated "event_map": // Map of EventID -> event JSON. "timeline": { // The recent events in the room if gap is "true" @@ -63,13 +66,13 @@ class SyncRestServlet(RestServlet): "ephemeral": {"events": []} // list of event objects } }, - "invited": {}, // Invited rooms being updated. - "archived": {} // Archived rooms being updated. + "invite": {}, // Invited rooms being updated. + "leave": {} // Archived rooms being updated. } } """ - PATTERN = client_v2_pattern("/sync$") + PATTERNS = client_v2_patterns("/sync$") ALLOWED_PRESENCE = set(["online", "offline"]) def __init__(self, hs): @@ -100,12 +103,21 @@ class SyncRestServlet(RestServlet): ) ) - try: - filter = yield self.filtering.get_user_filter( - user.localpart, filter_id - ) - except: - filter = FilterCollection({}) + if filter_id and filter_id.startswith('{'): + logging.error("MJH %r", filter_id) + try: + filter_object = json.loads(filter_id) + except: + raise SynapseError(400, "Invalid filter JSON") + self.filtering._check_valid_filter(filter_object) + filter = FilterCollection(filter_object) + else: + try: + filter = yield self.filtering.get_user_filter( + user.localpart, filter_id + ) + except: + filter = FilterCollection({}) sync_config = SyncConfig( user=user, @@ -144,13 +156,16 @@ class SyncRestServlet(RestServlet): ) response_content = { + "account_data": self.encode_account_data( + sync_result.account_data, filter, time_now + ), "presence": self.encode_presence( sync_result.presence, filter, time_now ), "rooms": { - "joined": joined, - "invited": invited, - "archived": archived, + "join": joined, + "invite": invited, + "leave": archived, }, "next_batch": sync_result.next_batch.to_string(), } @@ -165,6 +180,9 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": filter.filter_presence(formatted)} + def encode_account_data(self, events, filter, time_now): + return {"events": filter.filter_account_data(events)} + def encode_joined(self, rooms, filter, time_now, token_id): """ Encode the joined rooms in a sync result @@ -207,7 +225,7 @@ class SyncRestServlet(RestServlet): for room in rooms: invite = serialize_event( room.invite, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_event_id, + event_format=format_event_for_client_v2_without_room_id, ) invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) invited_state.append(invite) @@ -256,7 +274,13 @@ class SyncRestServlet(RestServlet): :return: the room, encoded in our response format :rtype: dict[str, object] """ - event_map = {} + def serialize(event): + # TODO(mjark): Respect formatting requirements in the filter. + return serialize_event( + event, time_now, token_id=token_id, + event_format=format_event_for_client_v2_without_room_id, + ) + state_dict = room.state timeline_events = filter.filter_room_timeline(room.timeline.events) @@ -264,37 +288,22 @@ class SyncRestServlet(RestServlet): state_dict, timeline_events) state_events = filter.filter_room_state(state_dict.values()) - state_event_ids = [] - for event in state_events: - # TODO(mjark): Respect formatting requirements in the filter. - event_map[event.event_id] = serialize_event( - event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_event_id, - ) - state_event_ids.append(event.event_id) - timeline_event_ids = [] - for event in timeline_events: - # TODO(mjark): Respect formatting requirements in the filter. - event_map[event.event_id] = serialize_event( - event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_event_id, - ) - timeline_event_ids.append(event.event_id) + serialized_state = [serialize(e) for e in state_events] + serialized_timeline = [serialize(e) for e in timeline_events] - private_user_data = filter.filter_room_private_user_data( - room.private_user_data + account_data = filter.filter_room_account_data( + room.account_data ) result = { - "event_map": event_map, "timeline": { - "events": timeline_event_ids, + "events": serialized_timeline, "prev_batch": room.timeline.prev_batch.to_string(), "limited": room.timeline.limited, }, - "state": {"events": state_event_ids}, - "private_user_data": {"events": private_user_data}, + "state": {"events": serialized_state}, + "account_data": {"events": account_data}, } if joined: diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 35482ae6a6..b5d0db5569 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_pattern +from ._base import client_v2_patterns from synapse.http.servlet import RestServlet from synapse.api.errors import AuthError, SynapseError @@ -31,7 +31,7 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" ) @@ -56,7 +56,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" ) @@ -81,7 +81,7 @@ class TagServlet(RestServlet): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) yield self.notifier.on_new_event( - "private_user_data_key", max_id, users=[user_id] + "account_data_key", max_id, users=[user_id] ) defer.returnValue((200, {})) @@ -95,7 +95,7 @@ class TagServlet(RestServlet): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) yield self.notifier.on_new_event( - "private_user_data_key", max_id, users=[user_id] + "account_data_key", max_id, users=[user_id] ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 901e777983..5a63afd51e 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request class TokenRefreshRestServlet(RestServlet): @@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ - PATTERN = client_v2_pattern("/tokenrefresh") + PATTERNS = client_v2_patterns("/tokenrefresh") def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() diff --git a/synapse/server.py b/synapse/server.py index f75d5358b2..f5c8329873 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -71,8 +71,7 @@ class BaseHomeServer(object): 'state_handler', 'notifier', 'distributor', - 'resource_for_client', - 'resource_for_client_v2_alpha', + 'client_resource', 'resource_for_federation', 'resource_for_static_content', 'resource_for_web_client', diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index ab8b4d44ea..bfb7386035 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -17,12 +17,11 @@ var submitPassword = function(user, pwd) { }).error(errorFunc); }; -var submitCas = function(ticket, service) { - console.log("Logging in with cas..."); +var submitToken = function(loginToken) { + console.log("Logging in with login token..."); var data = { - type: "m.login.cas", - ticket: ticket, - service: service, + type: "m.login.token", + token: loginToken }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { show_login(); @@ -41,23 +40,10 @@ var errorFunc = function(err) { } }; -var getCasURL = function(cb) { - $.get(matrixLogin.endpoint + "/cas", function(response) { - var cas_url = response.serverUrl; - - cb(cas_url); - }).error(errorFunc); -}; - - var gotoCas = function() { - getCasURL(function(cas_url) { - var this_page = window.location.origin + window.location.pathname; - - var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page); - - window.location.replace(redirect_url); - }); + var this_page = window.location.origin + window.location.pathname; + var redirect_url = matrixLogin.endpoint + "/cas/redirect?redirectUrl=" + encodeURIComponent(this_page); + window.location.replace(redirect_url); } var setFeedbackString = function(text) { @@ -111,7 +97,7 @@ var fetch_info = function(cb) { matrixLogin.onLoad = function() { fetch_info(function() { - if (!try_cas()) { + if (!try_token()) { show_login(); } }); @@ -148,20 +134,20 @@ var parseQsFromUrl = function(query) { return result; }; -var try_cas = function() { +var try_token = function() { var pos = window.location.href.indexOf("?"); if (pos == -1) { return false; } var qs = parseQsFromUrl(window.location.href.substr(pos+1)); - var ticket = qs.ticket; + var loginToken = qs.loginToken; - if (!ticket) { + if (!loginToken) { return false; } - submitCas(ticket, location.origin); + submitToken(loginToken); return true; }; diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e7443f2838..c46b653f11 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -42,6 +42,7 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore from .search import SearchStore from .tags import TagsStore +from .account_data import AccountDataStore import logging @@ -73,6 +74,7 @@ class DataStore(RoomMemberStore, RoomStore, EndToEndKeyStore, SearchStore, TagsStore, + AccountDataStore, ): def __init__(self, hs): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 218e708054..17a14e001c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -214,7 +214,8 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs): + def _new_transaction(self, conn, desc, after_callbacks, logging_context, + func, *args, **kwargs): start = time.time() * 1000 txn_id = self._TXN_ID @@ -277,6 +278,9 @@ class SQLBaseStore(object): end = time.time() * 1000 duration = end - start + if logging_context is not None: + logging_context.add_database_transaction(duration) + transaction_logger.debug("[TXN END] {%s} %f", name, duration) self._current_txn_total_time += duration @@ -302,7 +306,8 @@ class SQLBaseStore(object): current_context.copy_to(context) return self._new_transaction( - conn, desc, after_callbacks, func, *args, **kwargs + conn, desc, after_callbacks, current_context, + func, *args, **kwargs ) result = yield preserve_context_over_fn( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py new file mode 100644 index 0000000000..d1829f84e8 --- /dev/null +++ b/synapse/storage/account_data.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +import ujson as json +import logging + +logger = logging.getLogger(__name__) + + +class AccountDataStore(SQLBaseStore): + + def get_account_data_for_user(self, user_id): + """Get all the client account_data for a user. + + Args: + user_id(str): The user to get the account_data for. + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_account_data_for_user_txn(txn): + rows = self._simple_select_list_txn( + txn, "account_data", {"user_id": user_id}, + ["account_data_type", "content"] + ) + + global_account_data = { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + rows = self._simple_select_list_txn( + txn, "room_account_data", {"user_id": user_id}, + ["room_id", "account_data_type", "content"] + ) + + by_room = {} + for row in rows: + room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = json.loads(row["content"]) + + return (global_account_data, by_room) + + return self.runInteraction( + "get_account_data_for_user", get_account_data_for_user_txn + ) + + def get_account_data_for_room(self, user_id, room_id): + """Get all the client account_data for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + Returns: + A deferred dict of the room account_data + """ + def get_account_data_for_room_txn(txn): + rows = self._simple_select_list_txn( + txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, + ["account_data_type", "content"] + ) + + return { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + return self.runInteraction( + "get_account_data_for_room", get_account_data_for_room_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed. + + Args: + user_id(str): The user to get the account_data for. + stream_id(int): The point in the stream since which to get updates + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_updated_account_data_for_user_txn(txn): + sql = ( + "SELECT account_data_type, content FROM account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + global_account_data = { + row[0]: json.loads(row[1]) for row in txn.fetchall() + } + + sql = ( + "SELECT room_id, account_data_type, content FROM room_account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + account_data_by_room = {} + for row in txn.fetchall(): + room_account_data = account_data_by_room.setdefault(row[0], {}) + room_account_data[row[1]] = json.loads(row[2]) + + return (global_account_data, account_data_by_room) + + return self.runInteraction( + "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + ) + + @defer.inlineCallbacks + def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + def add_account_data_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + values={ + "stream_id": next_id, + "content": content_json, + } + ) + self._update_max_stream_id(txn, next_id) + + with (yield self._account_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction( + "add_room_account_data", add_account_data_txn, next_id + ) + + result = yield self._account_data_id_gen.get_max_token(self) + defer.returnValue(result) + + @defer.inlineCallbacks + def add_account_data_for_user(self, user_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + def add_account_data_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="account_data", + keyvalues={ + "user_id": user_id, + "account_data_type": account_data_type, + }, + values={ + "stream_id": next_id, + "content": content_json, + } + ) + self._update_max_stream_id(txn, next_id) + + with (yield self._account_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction( + "add_user_account_data", add_account_data_txn, next_id + ) + + result = yield self._account_data_id_gen.get_max_token(self) + defer.returnValue(result) + + def _update_max_stream_id(self, txn, next_id): + """Update the max stream_id + + Args: + txn: The database cursor + next_id(int): The the revision to advance to. + """ + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d35ca90b9..fc5725097c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -51,6 +51,14 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events class EventsStore(SQLBaseStore): + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + + def __init__(self, hs): + super(EventsStore, self).__init__(hs) + self.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + @defer.inlineCallbacks def persist_events(self, events_and_contexts, backfilled=False, is_new_state=True): @@ -365,6 +373,7 @@ class EventsStore(SQLBaseStore): "processed": True, "outlier": event.internal_metadata.is_outlier(), "content": encode_json(event.content).decode("UTF-8"), + "origin_server_ts": int(event.origin_server_ts), } for event, _ in events_and_contexts ], @@ -640,7 +649,7 @@ class EventsStore(SQLBaseStore): ] rows = self._new_transaction( - conn, "do_fetch", [], self._fetch_event_rows, event_ids + conn, "do_fetch", [], None, self._fetch_event_rows, event_ids ) row_dict = { @@ -964,3 +973,71 @@ class EventsStore(SQLBaseStore): ret = yield self.runInteraction("count_messages", _count_messages) defer.returnValue(ret) + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + events = self._get_events_txn(txn, event_ids) + + rows = [] + for event in events: + try: + event_id = event.event_id + origin_server_ts = event.origin_server_ts + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows.append((origin_server_ts, event_id)) + + sql = ( + "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + ) + + for index in range(0, len(rows), INSERT_CLUMP_SIZE): + clump = rows[index:index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows) + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + + defer.returnValue(result) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1a74d6e360..16eff62544 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 25 +SCHEMA_VERSION = 27 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ae1ad56d9a..69398b7c8e 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -160,7 +160,7 @@ class RoomMemberStore(SQLBaseStore): def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, membership_list): - where_clause = "user_id = ? AND (%s)" % ( + where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( " OR ".join(["membership = ?" for _ in membership_list]), ) @@ -269,3 +269,67 @@ class RoomMemberStore(SQLBaseStore): ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 defer.returnValue(ret) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + self.runInteraction("forget_membership", f) + + @defer.inlineCallbacks + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + count = yield self.runInteraction("did_forget_membership", f) + defer.returnValue(count == 0) + + @defer.inlineCallbacks + def was_forgotten_at(self, user_id, room_id, event_id): + """Returns whether user_id has elected to discard history for room_id at event_id. + + event_id must be a membership event.""" + def f(txn): + sql = ( + "SELECT" + " forgotten" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " event_id = ?" + ) + txn.execute(sql, (user_id, room_id, event_id)) + rows = txn.fetchall() + return rows[0][0] + forgot = yield self.runInteraction("did_forget_membership_at", f) + defer.returnValue(forgot == 1) diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/schema/delta/15/v15.sql index f5b2a08ca4..9523d2bcc3 100644 --- a/synapse/storage/schema/delta/15/v15.sql +++ b/synapse/storage/schema/delta/15/v15.sql @@ -1,23 +1,22 @@ -- Drop, copy & recreate pushers table to change unique key -- Also add access_token column at the same time CREATE TABLE IF NOT EXISTS pushers2 ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, - access_token INTEGER DEFAULT NULL, - profile_tag varchar(32) NOT NULL, - kind varchar(8) NOT NULL, - app_id varchar(64) NOT NULL, - app_display_name varchar(64) NOT NULL, - device_display_name varchar(128) NOT NULL, - pushkey blob NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey bytea NOT NULL, ts BIGINT NOT NULL, - lang varchar(8), - data blob, + lang VARCHAR(8), + data bytea, last_token TEXT, last_success BIGINT, failing_since BIGINT, - FOREIGN KEY(user_name) REFERENCES users(name), - UNIQUE (app_id, pushkey, user_name) + UNIQUE (app_id, pushkey) ); INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 5239d69073..ba48e43792 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -38,7 +38,7 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id); SQLITE_TABLE = ( - "CREATE VIRTUAL TABLE IF NOT EXISTS event_search" + "CREATE VIRTUAL TABLE event_search" " USING fts4 ( event_id, room_id, sender, key, value )" ) diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/schema/delta/26/account_data.sql new file mode 100644 index 0000000000..3198a0d29c --- /dev/null +++ b/synapse/storage/schema/delta/26/account_data.sql @@ -0,0 +1,17 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/schema/delta/27/account_data.sql new file mode 100644 index 0000000000..9f25416005 --- /dev/null +++ b/synapse/storage/schema/delta/27/account_data.sql @@ -0,0 +1,36 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS account_data( + user_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) +); + + +CREATE TABLE IF NOT EXISTS room_account_data( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) +); + + +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/schema/delta/27/forgotten_memberships.sql new file mode 100644 index 0000000000..beeb8a288b --- /dev/null +++ b/synapse/storage/schema/delta/27/forgotten_memberships.sql @@ -0,0 +1,26 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Keeps track of what rooms users have left and don't want to be able to + * access again. + * + * If all users on this server have left a room, we can delete the room + * entirely. + * + * This column should always contain either 0 or 1. + */ + + ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER DEFAULT 0; diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py new file mode 100644 index 0000000000..8d4a981975 --- /dev/null +++ b/synapse/storage/schema/delta/27/ts.py @@ -0,0 +1,57 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements + +import ujson + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = ( + "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;" + "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);" +) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = ujson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 380270b009..39f600f53c 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging +import re logger = logging.getLogger(__name__) @@ -139,7 +140,10 @@ class SearchStore(BackgroundUpdateStore): list of dicts """ clauses = [] - args = [] + + search_query = search_query = _parse_query(self.database_engine, search_term) + + args = [search_query] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -161,7 +165,7 @@ class SearchStore(BackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): sql = ( "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id" - " FROM plainto_tsquery('english', ?) as query, event_search" + " FROM to_tsquery('english', ?) as query, event_search" " WHERE vector @@ query" ) elif isinstance(self.database_engine, Sqlite3Engine): @@ -182,7 +186,7 @@ class SearchStore(BackgroundUpdateStore): sql += " ORDER BY rank DESC LIMIT 500" results = yield self._execute( - "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) + "search_msgs", self.cursor_to_dict, sql, *args ) results = filter(lambda row: row["room_id"] in room_ids, results) @@ -194,21 +198,28 @@ class SearchStore(BackgroundUpdateStore): for ev in events } - defer.returnValue([ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - } - for r in results - if r["event_id"] in event_map - ]) + highlights = None + if isinstance(self.database_engine, PostgresEngine): + highlights = yield self._find_highlights_in_postgres(search_query, events) + + defer.returnValue({ + "results": [ + { + "event": event_map[r["event_id"]], + "rank": r["rank"], + } + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + }) @defer.inlineCallbacks - def search_room(self, room_id, search_term, keys, limit, pagination_token=None): + def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): """Performs a full text search over events with given keys. Args: - room_id (str): The room_id to search in + room_id (list): The room_ids to search in search_term (str): Search term to search for keys (list): List of keys to search in, currently supports "content.body", "content.name", "content.topic" @@ -218,7 +229,18 @@ class SearchStore(BackgroundUpdateStore): list of dicts """ clauses = [] - args = [search_term, room_id] + + search_query = search_query = _parse_query(self.database_engine, search_term) + + args = [search_query] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clauses.append( + "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) + ) + args.extend(room_ids) local_clauses = [] for key in keys: @@ -231,25 +253,25 @@ class SearchStore(BackgroundUpdateStore): if pagination_token: try: - topo, stream = pagination_token.split(",") - topo = int(topo) + origin_server_ts, stream = pagination_token.split(",") + origin_server_ts = int(origin_server_ts) stream = int(stream) except: raise SynapseError(400, "Invalid pagination token") clauses.append( - "(topological_ordering < ?" - " OR (topological_ordering = ? AND stream_ordering < ?))" + "(origin_server_ts < ?" + " OR (origin_server_ts = ? AND stream_ordering < ?))" ) - args.extend([topo, topo, stream]) + args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): sql = ( "SELECT ts_rank_cd(vector, query) as rank," - " topological_ordering, stream_ordering, room_id, event_id" - " FROM plainto_tsquery('english', ?) as query, event_search" + " origin_server_ts, stream_ordering, room_id, event_id" + " FROM to_tsquery('english', ?) as query, event_search" " NATURAL JOIN events" - " WHERE vector @@ query AND room_id = ?" + " WHERE vector @@ query AND " ) elif isinstance(self.database_engine, Sqlite3Engine): # We use CROSS JOIN here to ensure we use the right indexes. @@ -262,24 +284,23 @@ class SearchStore(BackgroundUpdateStore): # MATCH unless it uses the full text search index sql = ( "SELECT rank(matchinfo) as rank, room_id, event_id," - " topological_ordering, stream_ordering" + " origin_server_ts, stream_ordering" " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" " FROM event_search" " WHERE value MATCH ?" " )" " CROSS JOIN events USING (event_id)" - " WHERE room_id = ?" + " WHERE " ) else: # This should be unreachable. raise Exception("Unrecognized database engine") - for clause in clauses: - sql += " AND " + clause + sql += " AND ".join(clauses) # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. - sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" + sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" args.append(limit) @@ -287,6 +308,8 @@ class SearchStore(BackgroundUpdateStore): "search_rooms", self.cursor_to_dict, sql, *args ) + results = filter(lambda row: row["room_id"] in room_ids, results) + events = yield self._get_events([r["event_id"] for r in results]) event_map = { @@ -294,14 +317,110 @@ class SearchStore(BackgroundUpdateStore): for ev in events } - defer.returnValue([ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" % ( - r["topological_ordering"], r["stream_ordering"] - ), - } - for r in results - if r["event_id"] in event_map - ]) + highlights = None + if isinstance(self.database_engine, PostgresEngine): + highlights = yield self._find_highlights_in_postgres(search_query, events) + + defer.returnValue({ + "results": [ + { + "event": event_map[r["event_id"]], + "rank": r["rank"], + "pagination_token": "%s,%s" % ( + r["origin_server_ts"], r["stream_ordering"] + ), + } + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + }) + + def _find_highlights_in_postgres(self, search_query, events): + """Given a list of events and a search term, return a list of words + that match from the content of the event. + + This is used to give a list of words that clients can match against to + highlight the matching parts. + + Args: + search_query (str) + events (list): A list of events + + Returns: + deferred : A set of strings. + """ + def f(txn): + highlight_words = set() + for event in events: + # As a hack we simply join values of all possible keys. This is + # fine since we're only using them to find possible highlights. + values = [] + for key in ("body", "name", "topic"): + v = event.content.get(key, None) + if v: + values.append(v) + + if not values: + continue + + value = " ".join(values) + + # We need to find some values for StartSel and StopSel that + # aren't in the value so that we can pick results out. + start_sel = "<" + stop_sel = ">" + + while start_sel in value: + start_sel += "<" + while stop_sel in value: + stop_sel += ">" + + query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + _to_postgres_options({ + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + }) + ) + txn.execute(query, (value, search_query,)) + headline, = txn.fetchall()[0] + + # Now we need to pick the possible highlights out of the haedline + # result. + matcher_regex = "%s(.*?)%s" % ( + re.escape(start_sel), + re.escape(stop_sel), + ) + + res = re.findall(matcher_regex, headline) + highlight_words.update([r.lower() for r in res]) + + return highlight_words + + return self.runInteraction("_find_highlights", f) + + +def _to_postgres_options(options_dict): + return "'%s'" % ( + ",".join("%s=%s" % (k, v) for k, v in options_dict.items()), + ) + + +def _parse_query(database_engine, search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + if isinstance(database_engine, PostgresEngine): + return " & ".join(result + ":*" for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join(result + "*" for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index bf695b7800..f520f60c6c 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -28,17 +28,17 @@ class TagsStore(SQLBaseStore): def __init__(self, hs): super(TagsStore, self).__init__(hs) - self._private_user_data_id_gen = StreamIdGenerator( - "private_user_data_max_stream_id", "stream_id" + self._account_data_id_gen = StreamIdGenerator( + "account_data_max_stream_id", "stream_id" ) - def get_max_private_user_data_stream_id(self): + def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream Returns: A deferred int. """ - return self._private_user_data_id_gen.get_max_token(self) + return self._account_data_id_gen.get_max_token(self) @cached() def get_tags_for_user(self, user_id): @@ -48,8 +48,8 @@ class TagsStore(SQLBaseStore): Args: user_id(str): The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to lists of tag - strings. + A deferred dict mapping from room_id strings to dicts mapping from + tag strings to tag content. """ deferred = self._simple_select_list( @@ -144,12 +144,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + with (yield self._account_data_id_gen.get_next(self)) as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._private_user_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token(self) defer.returnValue(result) @defer.inlineCallbacks @@ -166,12 +166,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + with (yield self._account_data_id_gen.get_next(self)) as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._private_user_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token(self) defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): @@ -185,7 +185,7 @@ class TagsStore(SQLBaseStore): """ update_max_id_sql = ( - "UPDATE private_user_data_max_stream_id" + "UPDATE account_data_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" ) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f0d68b5bf2..cfa7d30fa5 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -21,7 +21,7 @@ from synapse.handlers.presence import PresenceEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -from synapse.handlers.private_user_data import PrivateUserDataEventSource +from synapse.handlers.account_data import AccountDataEventSource class EventSources(object): @@ -30,7 +30,7 @@ class EventSources(object): "presence": PresenceEventSource, "typing": TypingNotificationEventSource, "receipt": ReceiptEventSource, - "private_user_data": PrivateUserDataEventSource, + "account_data": AccountDataEventSource, } def __init__(self, hs): @@ -54,8 +54,8 @@ class EventSources(object): receipt_key=( yield self.sources["receipt"].get_current_key() ), - private_user_data_key=( - yield self.sources["private_user_data"].get_current_key() + account_data_key=( + yield self.sources["account_data"].get_current_key() ), ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 28344d8b36..af1d76ab46 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -103,7 +103,7 @@ class StreamToken( "presence_key", "typing_key", "receipt_key", - "private_user_data_key", + "account_data_key", )) ): _SEPARATOR = "_" @@ -138,7 +138,7 @@ class StreamToken( or (int(other.presence_key) < int(self.presence_key)) or (int(other.typing_key) < int(self.typing_key)) or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.private_user_data_key) < int(self.private_user_data_key)) + or (int(other.account_data_key) < int(self.account_data_key)) ) def copy_and_advance(self, key, new_value): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index d69c7cb991..2170746025 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -64,8 +64,7 @@ class Clock(object): current_context = LoggingContext.current_context() def wrapped_callback(*args, **kwargs): - with PreserveLoggingContext(): - LoggingContext.thread_local.current_context = current_context + with PreserveLoggingContext(current_context): callback(*args, **kwargs) with PreserveLoggingContext(): diff --git a/synapse/util/debug.py b/synapse/util/debug.py index f6a5a841a4..b2bee7958f 100644 --- a/synapse/util/debug.py +++ b/synapse/util/debug.py @@ -30,8 +30,7 @@ def debug_deferreds(): context = LoggingContext.current_context() def restore_context_callback(x): - with PreserveLoggingContext(): - LoggingContext.thread_local.current_context = context + with PreserveLoggingContext(context): return fn(x) return restore_context_callback diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 7e6062c1b8..d528ced55a 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -19,6 +19,25 @@ import logging logger = logging.getLogger(__name__) +try: + import resource + + # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined + # to be 1 on linux so we hard code it. + RUSAGE_THREAD = 1 + + # If the system doesn't support RUSAGE_THREAD then this should throw an + # exception. + resource.getrusage(RUSAGE_THREAD) + + def get_thread_resource_usage(): + return resource.getrusage(RUSAGE_THREAD) +except: + # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we + # won't track resource usage by returning None. + def get_thread_resource_usage(): + return None + class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a @@ -27,7 +46,9 @@ class LoggingContext(object): name (str): Name for the context for debugging. """ - __slots__ = ["parent_context", "name", "__dict__"] + __slots__ = [ + "parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__" + ] thread_local = threading.local() @@ -42,11 +63,26 @@ class LoggingContext(object): def copy_to(self, record): pass + def start(self): + pass + + def stop(self): + pass + + def add_database_transaction(self, duration_ms): + pass + sentinel = Sentinel() def __init__(self, name=None): self.parent_context = None self.name = name + self.ru_stime = 0. + self.ru_utime = 0. + self.db_txn_count = 0 + self.db_txn_duration = 0. + self.usage_start = None + self.main_thread = threading.current_thread() def __str__(self): return "%s@%x" % (self.name, id(self)) @@ -56,12 +92,26 @@ class LoggingContext(object): """Get the current logging context from thread local storage""" return getattr(cls.thread_local, "current_context", cls.sentinel) + @classmethod + def set_current_context(cls, context): + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + current = cls.current_context() + if current is not context: + current.stop() + cls.thread_local.current_context = context + context.start() + return current + def __enter__(self): """Enters this logging context into thread local storage""" if self.parent_context is not None: raise Exception("Attempt to enter logging context multiple times") - self.parent_context = self.current_context() - self.thread_local.current_context = self + self.parent_context = self.set_current_context(self) return self def __exit__(self, type, value, traceback): @@ -70,16 +120,16 @@ class LoggingContext(object): Returns: None to avoid suppressing any exeptions that were thrown. """ - if self.thread_local.current_context is not self: - if self.thread_local.current_context is self.sentinel: + current = self.set_current_context(self.parent_context) + if current is not self: + if current is self.sentinel: logger.debug("Expected logging context %s has been lost", self) else: logger.warn( "Current logging context %s is not expected context %s", - self.thread_local.current_context, + current, self ) - self.thread_local.current_context = self.parent_context self.parent_context = None def __getattr__(self, name): @@ -93,6 +143,43 @@ class LoggingContext(object): for key, value in self.__dict__.items(): setattr(record, key, value) + record.ru_utime, record.ru_stime = self.get_resource_usage() + + def start(self): + if threading.current_thread() is not self.main_thread: + return + + if self.usage_start and self.usage_end: + self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime + self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime + self.usage_start = None + self.usage_end = None + + if not self.usage_start: + self.usage_start = get_thread_resource_usage() + + def stop(self): + if threading.current_thread() is not self.main_thread: + return + + if self.usage_start: + self.usage_end = get_thread_resource_usage() + + def get_resource_usage(self): + ru_utime = self.ru_utime + ru_stime = self.ru_stime + + if self.usage_start and threading.current_thread() is self.main_thread: + current = get_thread_resource_usage() + ru_utime += current.ru_utime - self.usage_start.ru_utime + ru_stime += current.ru_stime - self.usage_start.ru_stime + + return ru_utime, ru_stime + + def add_database_transaction(self, duration_ms): + self.db_txn_count += 1 + self.db_txn_duration += duration_ms / 1000. + class LoggingContextFilter(logging.Filter): """Logging filter that adds values from the current logging context to each @@ -121,17 +208,20 @@ class PreserveLoggingContext(object): exited. Used to restore the context after a function using @defer.inlineCallbacks is resumed by a callback from the reactor.""" - __slots__ = ["current_context"] + __slots__ = ["current_context", "new_context"] + + def __init__(self, new_context=LoggingContext.sentinel): + self.new_context = new_context def __enter__(self): """Captures the current logging context""" - self.current_context = LoggingContext.current_context() - LoggingContext.thread_local.current_context = LoggingContext.sentinel + self.current_context = LoggingContext.set_current_context( + self.new_context + ) def __exit__(self, type, value, traceback): """Restores the current logging context""" - LoggingContext.thread_local.current_context = self.current_context - + LoggingContext.set_current_context(self.current_context) if self.current_context is not LoggingContext.sentinel: if self.current_context.parent_context is None: logger.warn( @@ -164,8 +254,7 @@ class _PreservingContextDeferred(defer.Deferred): def _wrap_callback(self, f): def g(res, *args, **kwargs): - with PreserveLoggingContext(): - LoggingContext.thread_local.current_context = self._log_context + with PreserveLoggingContext(self._log_context): res = f(res, *args, **kwargs) return res return g |