diff options
Diffstat (limited to 'synapse')
29 files changed, 552 insertions, 165 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index f68a15bb85..7ff37edf2c 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.11.0-rc1" +__version__ = "0.11.0-r2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3e891a6193..4fdc779b4b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -207,6 +207,13 @@ class Auth(object): user_id, room_id )) + if membership == Membership.LEAVE: + forgot = yield self.store.did_forget(user_id, room_id) + if forgot: + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) + defer.returnValue(member) @defer.inlineCallbacks @@ -587,7 +594,7 @@ class Auth(object): def _get_user_from_macaroon(self, macaroon_str): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - self._validate_macaroon(macaroon) + self.validate_macaroon(macaroon, "access", False) user_prefix = "user_id = " user = None @@ -635,13 +642,27 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) - def _validate_macaroon(self, macaroon): + def validate_macaroon(self, macaroon, type_string, verify_expiry): + """ + validate that a Macaroon is understood by and was signed by this server. + + Args: + macaroon(pymacaroons.Macaroon): The macaroon to validate + type_string(str): The kind of token this is (e.g. "access", "refresh") + verify_expiry(bool): Whether to verify whether the macaroon has expired. + This should really always be True, but no clients currently implement + token refresh, so we can't enforce expiry yet. + """ v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") - v.satisfy_exact("type = access") + v.satisfy_exact("type = " + type_string) v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(self._verify_expiry) v.satisfy_exact("guest = true") + if verify_expiry: + v.satisfy_general(self._verify_expiry) + else: + v.satisfy_general(lambda c: c.startswith("time < ")) + v.verify(macaroon, self.hs.config.macaroon_secret_key) v = pymacaroons.Verifier() @@ -652,9 +673,6 @@ class Auth(object): prefix = "time < " if not caveat.startswith(prefix): return False - # TODO(daniel): Enable expiry check when clients actually know how to - # refresh tokens. (And remember to enable the tests) - return True expiry = int(caveat[len(prefix):]) now = self.hs.get_clock().time_msec() return now < expiry diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index aaa2433cae..18f2ec3ae8 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -54,7 +54,7 @@ class Filtering(object): ] room_level_definitions = [ - "state", "timeline", "ephemeral", "private_user_data" + "state", "timeline", "ephemeral", "account_data" ] for key in top_level_definitions: @@ -131,8 +131,8 @@ class FilterCollection(object): self.filter_json.get("room", {}).get("ephemeral", {}) ) - self.room_private_user_data = Filter( - self.filter_json.get("room", {}).get("private_user_data", {}) + self.room_account_data = Filter( + self.filter_json.get("room", {}).get("account_data", {}) ) self.presence_filter = Filter( @@ -160,8 +160,8 @@ class FilterCollection(object): def filter_room_ephemeral(self, events): return self.room_ephemeral_filter.filter(events) - def filter_room_private_user_data(self, events): - return self.room_private_user_data.filter(events) + def filter_room_account_data(self, events): + return self.room_account_data.filter(events) class Filter(object): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index c18e0bdbb8..d0c9972445 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -25,18 +25,29 @@ class ConfigError(Exception): pass -class Config(object): +# We split these messages out to allow packages to override with package +# specific instructions. +MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\ +Please opt in or out of reporting anonymized homeserver usage statistics, by +setting the `report_stats` key in your config file to either True or False. +""" + +MISSING_REPORT_STATS_SPIEL = """\ +We would really appreciate it if you could help our project out by reporting +anonymized usage statistics from your homeserver. Only very basic aggregate +data (e.g. number of users) will be reported, but it helps us to track the +growth of the Matrix community, and helps us to make Matrix a success, as well +as to convince other networks that they should peer with us. + +Thank you. +""" + +MISSING_SERVER_NAME = """\ +Missing mandatory `server_name` config option. +""" - stats_reporting_begging_spiel = ( - "We would really appreciate it if you could help our project out by" - " reporting anonymized usage statistics from your homeserver. Only very" - " basic aggregate data (e.g. number of users) will be reported, but it" - " helps us to track the growth of the Matrix community, and helps us to" - " make Matrix a success, as well as to convince other networks that they" - " should peer with us." - "\nThank you." - ) +class Config(object): @staticmethod def parse_size(value): if isinstance(value, int) or isinstance(value, long): @@ -215,7 +226,7 @@ class Config(object): if config_args.report_stats is None: config_parser.error( "Please specify either --report-stats=yes or --report-stats=no\n\n" + - cls.stats_reporting_begging_spiel + MISSING_REPORT_STATS_SPIEL ) if not config_files: config_parser.error( @@ -290,6 +301,10 @@ class Config(object): yaml_config = cls.read_config_file(config_file) specified_config.update(yaml_config) + if "server_name" not in specified_config: + sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n") + sys.exit(1) + server_name = specified_config["server_name"] _, config = obj.generate_config( config_dir_path=config_dir_path, @@ -299,11 +314,8 @@ class Config(object): config.update(specified_config) if "report_stats" not in config: sys.stderr.write( - "Please opt in or out of reporting anonymized homeserver usage " - "statistics, by setting the report_stats key in your config file " - " ( " + config_path + " ) " + - "to either True or False.\n\n" + - Config.stats_reporting_begging_spiel + "\n") + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + + MISSING_REPORT_STATS_SPIEL + "\n") sys.exit(1) if generate_keys: diff --git a/synapse/config/cas.py b/synapse/config/cas.py index a337ae6ca0..326e405841 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -27,10 +27,12 @@ class CasConfig(Config): if cas_config: self.cas_enabled = cas_config.get("enabled", True) self.cas_server_url = cas_config["server_url"] + self.cas_service_url = cas_config["service_url"] self.cas_required_attributes = cas_config.get("required_attributes", {}) else: self.cas_enabled = False self.cas_server_url = None + self.cas_service_url = None self.cas_required_attributes = {} def default_config(self, config_dir_path, server_name, **kwargs): @@ -39,6 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" + # service_url: "https://homesever.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 8b6a59866f..bc5bb5cdb1 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -381,28 +381,24 @@ class Keyring(object): def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, perspective_name, perspective_keys): - limiter = yield get_retry_limiter( - perspective_name, self.clock, self.store - ) - - with limiter: - # TODO(mark): Set the minimum_valid_until_ts to that needed by - # the events being validated or the current time if validating - # an incoming request. - query_response = yield self.client.post_json( - destination=perspective_name, - path=b"/_matrix/key/v2/query", - data={ - u"server_keys": { - server_name: { - key_id: { - u"minimum_valid_until_ts": 0 - } for key_id in key_ids - } - for server_name, key_ids in server_names_and_key_ids + # TODO(mark): Set the minimum_valid_until_ts to that needed by + # the events being validated or the current time if validating + # an incoming request. + query_response = yield self.client.post_json( + destination=perspective_name, + path=b"/_matrix/key/v2/query", + data={ + u"server_keys": { + server_name: { + key_id: { + u"minimum_valid_until_ts": 0 + } for key_id in key_ids } - }, - ) + for server_name, key_ids in server_names_and_key_ids + } + }, + long_retries=True, + ) keys = {} diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9989b76591..44cc1ef132 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -129,10 +129,9 @@ def format_event_for_client_v2(d): return d -def format_event_for_client_v2_without_event_id(d): +def format_event_for_client_v2_without_room_id(d): d = format_event_for_client_v2(d) d.pop("room_id", None) - d.pop("event_id", None) return d diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d4f586fae7..c6a8c1249a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -401,6 +401,12 @@ class FederationClient(FederationBase): pdu_dict["content"].update(content) + # The protoevent received over the JSON wire may not have all + # the required fields. Lets just gloss over that because + # there's some we never care about + if "prev_state" not in pdu_dict: + pdu_dict["prev_state"] = [] + defer.returnValue( (destination, self.event_from_pdu_json(pdu_dict)) ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 3d59e1c650..0e0cb7ebc6 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -136,6 +136,7 @@ class TransportLayerClient(object): path=PREFIX + "/send/%s/" % transaction.transaction_id, data=json_data, json_data_callback=json_data_callback, + long_retries=True, ) logger.debug( diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6519f183df..5fd20285d2 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -92,7 +92,15 @@ class BaseHandler(object): membership_event = state.get((EventTypes.Member, user_id), None) if membership_event: - membership = membership_event.membership + was_forgotten_at_event = yield self.store.was_forgotten_at( + membership_event.state_key, + membership_event.room_id, + membership_event.event_id + ) + if was_forgotten_at_event: + membership = None + else: + membership = membership_event.membership else: membership = None diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/account_data.py index 1abe45ed7b..1d35d3b7dc 100644 --- a/synapse/handlers/private_user_data.py +++ b/synapse/handlers/account_data.py @@ -16,19 +16,19 @@ from twisted.internet import defer -class PrivateUserDataEventSource(object): +class AccountDataEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() def get_current_key(self, direction='f'): - return self.store.get_max_private_user_data_stream_id() + return self.store.get_max_account_data_stream_id() @defer.inlineCallbacks def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_private_user_data_stream_id() + current_stream_id = yield self.store.get_max_account_data_stream_id() tags = yield self.store.get_updated_tags(user_id, last_stream_id) results = [] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1b11dbdffd..e64b67cdfd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID -from synapse.api.errors import LoginError, Codes +from synapse.api.errors import AuthError, LoginError, Codes from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -46,6 +46,7 @@ class AuthHandler(BaseHandler): } self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} + self.INVALID_TOKEN_HTTP_STATUS = 401 @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -297,10 +298,11 @@ class AuthHandler(BaseHandler): defer.returnValue((user_id, access_token, refresh_token)) @defer.inlineCallbacks - def login_with_cas_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id): """ - Authenticates the user with the given user ID, - intended to have been captured from a CAS response + Gets login tuple for the user with the given user ID. + The user is assumed to have been authenticated by some other + machanism (e.g. CAS) Args: user_id (str): User ID @@ -393,6 +395,23 @@ class AuthHandler(BaseHandler): )) return m.serialize() + def generate_short_term_login_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = login") + now = self.hs.get_clock().time_msec() + expiry = now + (2 * 60 * 1000) + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def validate_short_term_login_token_and_get_user_id(self, login_token): + try: + macaroon = pymacaroons.Macaroon.deserialize(login_token) + auth_api = self.hs.get_auth() + auth_api.validate_macaroon(macaroon, "login", True) + return self._get_user_from_macaroon(macaroon) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, @@ -402,6 +421,16 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon + def _get_user_from_macaroon(self, macaroon): + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", + errcode=Codes.UNKNOWN_TOKEN + ) + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = self.hash(newpassword) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a92409c6a2..64c57375f7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -436,14 +436,14 @@ class MessageHandler(BaseHandler): for c in current_state.values() ] - private_user_data = [] + account_data = [] tags = tags_by_room.get(event.room_id) if tags: - private_user_data.append({ + account_data.append({ "type": "m.tag", "content": {"tags": tags}, }) - d["private_user_data"] = private_user_data + d["account_data"] = account_data except: logger.exception("Failed to get snapshot") @@ -498,14 +498,14 @@ class MessageHandler(BaseHandler): user_id, room_id, pagin_config, membership, member_event_id, is_guest ) - private_user_data = [] + account_data = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - private_user_data.append({ + account_data.append({ "type": "m.tag", "content": {"tags": tags}, }) - result["private_user_data"] = private_user_data + result["account_data"] = account_data defer.returnValue(result) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3f04752581..023b4001b8 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -743,6 +743,9 @@ class RoomMemberHandler(BaseHandler): ) defer.returnValue((token, public_key, key_validity_url, display_name)) + def forget(self, user, room_id): + self.store.forget(user.to_string(), room_id) + class RoomListHandler(BaseHandler): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index b7545c111f..50688e51a8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -17,13 +17,14 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.constants import Membership +from synapse.api.constants import Membership, EventTypes from synapse.api.filtering import Filter from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from unpaddedbase64 import decode_base64, encode_base64 +import itertools import logging @@ -79,6 +80,9 @@ class SearchHandler(BaseHandler): # What to order results by (impacts whether pagination can be doen) order_by = room_cat.get("order_by", "rank") + # Return the current state of the rooms? + include_state = room_cat.get("include_state", False) + # Include context around each event? event_context = room_cat.get( "event_context", None @@ -96,6 +100,10 @@ class SearchHandler(BaseHandler): after_limit = int(event_context.get( "after_limit", 5 )) + + # Return the historic display name and avatar for the senders + # of the events? + include_profile = bool(event_context.get("include_profile", False)) except KeyError: raise SynapseError(400, "Invalid search query") @@ -269,6 +277,33 @@ class SearchHandler(BaseHandler): "room_key", res["end"] ).to_string() + if include_profile: + senders = set( + ev.sender + for ev in itertools.chain( + res["events_before"], [event], res["events_after"] + ) + ) + + if res["events_after"]: + last_event_id = res["events_after"][-1].event_id + else: + last_event_id = event.event_id + + state = yield self.store.get_state_for_event( + last_event_id, + types=[(EventTypes.Member, sender) for sender in senders] + ) + + res["profile_info"] = { + s.state_key: { + "displayname": s.content.get("displayname", None), + "avatar_url": s.content.get("avatar_url", None), + } + for s in state.values() + if s.type == EventTypes.Member and s.state_key in senders + } + contexts[event.event_id] = res else: contexts = {} @@ -287,6 +322,18 @@ class SearchHandler(BaseHandler): for e in context["events_after"] ] + state_results = {} + if include_state: + rooms = set(e.room_id for e in allowed_events) + for room_id in rooms: + state = yield self.state_handler.get_current_state(room_id) + state_results[room_id] = state.values() + + state_results.values() + + # We're now about to serialize the events. We should not make any + # blocking calls after this. Otherwise the 'age' will be wrong + results = { e.event_id: { "rank": rank_map[e.event_id], @@ -303,6 +350,12 @@ class SearchHandler(BaseHandler): "count": len(results) } + if state_results: + rooms_cat_res["state"] = { + room_id: [serialize_event(e, time_now) for e in state] + for room_id, state in state_results.items() + } + if room_groups and "room_id" in group_keys: rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6dc9d0fb92..877328b29e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -51,7 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "timeline", # TimelineBatch "state", # dict[(str, str), FrozenEvent] "ephemeral", - "private_user_data", + "account_data", ])): __slots__ = [] @@ -63,7 +63,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ self.timeline or self.state or self.ephemeral - or self.private_user_data + or self.account_data ) @@ -71,7 +71,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ "room_id", # str "timeline", # TimelineBatch "state", # dict[(str, str), FrozenEvent] - "private_user_data", + "account_data", ])): __slots__ = [] @@ -82,7 +82,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ return bool( self.timeline or self.state - or self.private_user_data + or self.account_data ) @@ -261,20 +261,20 @@ class SyncHandler(BaseHandler): timeline=batch, state=current_state, ephemeral=ephemeral_by_room.get(room_id, []), - private_user_data=self.private_user_data_for_room( + account_data=self.account_data_for_room( room_id, tags_by_room ), )) - def private_user_data_for_room(self, room_id, tags_by_room): - private_user_data = [] + def account_data_for_room(self, room_id, tags_by_room): + account_data = [] tags = tags_by_room.get(room_id) if tags is not None: - private_user_data.append({ + account_data.append({ "type": "m.tag", "content": {"tags": tags}, }) - return private_user_data + return account_data @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): @@ -357,7 +357,7 @@ class SyncHandler(BaseHandler): room_id=room_id, timeline=batch, state=leave_state, - private_user_data=self.private_user_data_for_room( + account_data=self.account_data_for_room( room_id, tags_by_room ), )) @@ -412,7 +412,7 @@ class SyncHandler(BaseHandler): tags_by_room = yield self.store.get_updated_tags( sync_config.user.to_string(), - since_token.private_user_data_key, + since_token.account_data_key, ) joined = [] @@ -468,7 +468,7 @@ class SyncHandler(BaseHandler): ), state=state, ephemeral=ephemeral_by_room.get(room_id, []), - private_user_data=self.private_user_data_for_room( + account_data=self.account_data_for_room( room_id, tags_by_room ), ) @@ -605,7 +605,7 @@ class SyncHandler(BaseHandler): timeline=batch, state=state, ephemeral=ephemeral_by_room.get(room_id, []), - private_user_data=self.private_user_data_for_room( + account_data=self.account_data_for_room( room_id, tags_by_room ), ) @@ -653,7 +653,7 @@ class SyncHandler(BaseHandler): room_id=leave_event.room_id, timeline=batch, state=state_events_delta, - private_user_data=self.private_user_data_for_room( + account_data=self.account_data_for_room( leave_event.room_id, tags_by_room ), ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 6e53538a52..b7b7c2cce8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -56,7 +56,8 @@ incoming_responses_counter = metrics.register_counter( ) -MAX_RETRIES = 4 +MAX_LONG_RETRIES = 10 +MAX_SHORT_RETRIES = 3 class MatrixFederationEndpointFactory(object): @@ -103,7 +104,7 @@ class MatrixFederationHttpClient(object): def _create_request(self, destination, method, path_bytes, body_callback, headers_dict={}, param_bytes=b"", query_bytes=b"", retry_on_dns_fail=True, - timeout=None): + timeout=None, long_retries=False): """ Creates and sends a request to the given url """ headers_dict[b"User-Agent"] = [self.version_string] @@ -123,7 +124,10 @@ class MatrixFederationHttpClient(object): # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) - retries_left = MAX_RETRIES + if long_retries: + retries_left = MAX_LONG_RETRIES + else: + retries_left = MAX_SHORT_RETRIES http_url_bytes = urlparse.urlunparse( ("", "", path_bytes, param_bytes, query_bytes, "") @@ -184,8 +188,15 @@ class MatrixFederationHttpClient(object): ) if retries_left and not timeout: - delay = 5 ** (MAX_RETRIES + 1 - retries_left) - delay *= random.uniform(0.8, 1.4) + if long_retries: + delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) + delay = min(delay, 60) + delay *= random.uniform(0.8, 1.4) + else: + delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) + delay = min(delay, 2) + delay *= random.uniform(0.8, 1.4) + yield sleep(delay) retries_left -= 1 else: @@ -236,7 +247,8 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, data={}, json_data_callback=None): + def put_json(self, destination, path, data={}, json_data_callback=None, + long_retries=False): """ Sends the specifed json data using PUT Args: @@ -247,6 +259,8 @@ class MatrixFederationHttpClient(object): the request body. This will be encoded as JSON. json_data_callback (callable): A callable returning the dict to use as the request body. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -272,6 +286,7 @@ class MatrixFederationHttpClient(object): path.encode("ascii"), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + long_retries=long_retries, ) if 200 <= response.code < 300: @@ -287,7 +302,7 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json(self, destination, path, data={}): + def post_json(self, destination, path, data={}, long_retries=True): """ Sends the specifed json data using POST Args: @@ -296,6 +311,8 @@ class MatrixFederationHttpClient(object): path (str): The HTTP path. data (dict): A dict containing the data that will be used as the request body. This will be encoded as JSON. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -315,6 +332,7 @@ class MatrixFederationHttpClient(object): path.encode("ascii"), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + long_retries=True, ) if 200 <= response.code < 300: @@ -490,6 +508,9 @@ class _JsonProducer(object): def stopProducing(self): pass + def resumeProducing(self): + pass + def _flatten_response_never_received(e): if hasattr(e, "reasons"): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4ea06c1434..720d6358e7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -22,6 +22,7 @@ from base import ClientV1RestServlet, client_path_pattern import simplejson as json import urllib +import urlparse import logging from saml2 import BINDING_HTTP_POST @@ -39,6 +40,7 @@ class LoginRestServlet(ClientV1RestServlet): PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" + TOKEN_TYPE = "m.login.token" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) @@ -56,8 +58,18 @@ class LoginRestServlet(ClientV1RestServlet): flows.append({"type": LoginRestServlet.SAML2_TYPE}) if self.cas_enabled: flows.append({"type": LoginRestServlet.CAS_TYPE}) + + # While its valid for us to advertise this login type generally, + # synapse currently only gives out these tokens as part of the + # CAS login flow. + # Generally we don't want to advertise login flows that clients + # don't know how to implement, since they (currently) will always + # fall back to the fallback API if they don't understand one of the + # login flow types returned. + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) if self.password_enabled: flows.append({"type": LoginRestServlet.PASS_TYPE}) + return (200, {"flows": flows}) def on_OPTIONS(self, request): @@ -83,6 +95,7 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): # TODO: get this from the homeserver rather than creating a new one for @@ -96,6 +109,9 @@ class LoginRestServlet(ClientV1RestServlet): body = yield http_client.get_raw(uri, args) result = yield self.do_cas_login(body) defer.returnValue(result) + elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + result = yield self.do_token_login(login_submission) + defer.returnValue(result) else: raise SynapseError(400, "Bad login type.") except KeyError: @@ -132,6 +148,26 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) @defer.inlineCallbacks + def do_token_login(self, login_submission): + token = login_submission['token'] + auth_handler = self.handlers.auth_handler + user_id = ( + yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + ) + user_id, access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + # TODO Delete this after all CAS clients switch to token login instead + @defer.inlineCallbacks def do_cas_login(self, cas_response_body): user, attributes = self.parse_cas_response(cas_response_body) @@ -152,7 +188,7 @@ class LoginRestServlet(ClientV1RestServlet): user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( - yield auth_handler.login_with_cas_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id) ) result = { "user_id": user_id, # may have changed @@ -173,6 +209,7 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): @@ -243,6 +280,7 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) +# TODO Delete this after all CAS clients switch to token login instead class CasRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/cas") @@ -254,6 +292,115 @@ class CasRestServlet(ClientV1RestServlet): return (200, {"serverUrl": self.cas_server_url}) +class CasRedirectServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/cas/redirect") + + def __init__(self, hs): + super(CasRedirectServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url + + def on_GET(self, request): + args = request.args + if "redirectUrl" not in args: + return (400, "Redirect URL not specified for CAS auth") + client_redirect_url_param = urllib.urlencode({ + "redirectUrl": args["redirectUrl"][0] + }) + hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" + service_param = urllib.urlencode({ + "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) + }) + request.redirect("%s?%s" % (self.cas_server_url, service_param)) + request.finish() + + +class CasTicketServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/cas/ticket") + + def __init__(self, hs): + super(CasTicketServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url + self.cas_required_attributes = hs.config.cas_required_attributes + + @defer.inlineCallbacks + def on_GET(self, request): + client_redirect_url = request.args["redirectUrl"][0] + http_client = self.hs.get_simple_http_client() + uri = self.cas_server_url + "/proxyValidate" + args = { + "ticket": request.args["ticket"], + "service": self.cas_service_url + } + body = yield http_client.get_raw(uri, args) + result = yield self.handle_cas_response(request, body, client_redirect_url) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_cas_response(self, request, cas_response_body, client_redirect_url): + user, attributes = self.parse_cas_response(cas_response_body) + + for required_attribute, required_value in self.cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if not user_exists: + user_id, _ = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + + login_token = auth_handler.generate_short_term_login_token(user_id) + redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, + login_token) + request.redirect(redirect_url) + request.finish() + + def add_login_token_to_redirect_url(self, url, token): + url_parts = list(urlparse.urlparse(url)) + query = dict(urlparse.parse_qsl(url_parts[4])) + query.update({"loginToken": token}) + url_parts[4] = urllib.urlencode(query) + return urlparse.urlunparse(url_parts) + + def parse_cas_response(self, cas_response_body): + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not root[0].tag.endswith("authenticationSuccess"): + raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in attribute tags + # to the full URL of the namespace. + # See (https://docs.python.org/2/library/xml.etree.elementtree.html) + # We don't care about namespace here and it will always be encased in + # curly braces, so we remove them. + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + + return (user, attributes) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -269,5 +416,7 @@ def register_servlets(hs, http_server): if hs.config.saml2_enabled: SAML2RestServlet(hs).register(http_server) if hs.config.cas_enabled: + CasRedirectServlet(hs).register(http_server) + CasTicketServlet(hs).register(http_server) CasRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 139dac1cc3..6952d269ec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -448,7 +448,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERN = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|kick)") + "(?P<membership_action>join|invite|leave|ban|kick|forget)") register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks @@ -458,6 +458,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): allow_guest=True ) + effective_membership_action = membership_action + if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}: raise AuthError(403, "Guest access not allowed") @@ -488,11 +490,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet): UserID.from_string(state_key) if membership_action == "kick": - membership_action = "leave" + effective_membership_action = "leave" + elif membership_action == "forget": + effective_membership_action = "leave" msg_handler = self.handlers.message_handler - content = {"membership": unicode(membership_action)} + content = {"membership": unicode(effective_membership_action)} if is_guest: content["kind"] = "guest" @@ -509,6 +513,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): is_guest=is_guest, ) + if membership_action == "forget": + self.handlers.room_member_handler.forget(user, room_id) + defer.returnValue((200, {})) def _has_3pid_invite_keys(self, content): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index efd8281558..775f49885b 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -22,7 +22,7 @@ from synapse.handlers.sync import SyncConfig from synapse.types import StreamToken from synapse.events import FrozenEvent from synapse.events.utils import ( - serialize_event, format_event_for_client_v2_without_event_id, + serialize_event, format_event_for_client_v2_without_room_id, ) from synapse.api.filtering import FilterCollection from ._base import client_v2_pattern @@ -148,9 +148,9 @@ class SyncRestServlet(RestServlet): sync_result.presence, filter, time_now ), "rooms": { - "joined": joined, - "invited": invited, - "archived": archived, + "join": joined, + "invite": invited, + "leave": archived, }, "next_batch": sync_result.next_batch.to_string(), } @@ -207,7 +207,7 @@ class SyncRestServlet(RestServlet): for room in rooms: invite = serialize_event( room.invite, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_event_id, + event_format=format_event_for_client_v2_without_room_id, ) invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) invited_state.append(invite) @@ -256,7 +256,13 @@ class SyncRestServlet(RestServlet): :return: the room, encoded in our response format :rtype: dict[str, object] """ - event_map = {} + def serialize(event): + # TODO(mjark): Respect formatting requirements in the filter. + return serialize_event( + event, time_now, token_id=token_id, + event_format=format_event_for_client_v2_without_room_id, + ) + state_dict = room.state timeline_events = filter.filter_room_timeline(room.timeline.events) @@ -264,37 +270,22 @@ class SyncRestServlet(RestServlet): state_dict, timeline_events) state_events = filter.filter_room_state(state_dict.values()) - state_event_ids = [] - for event in state_events: - # TODO(mjark): Respect formatting requirements in the filter. - event_map[event.event_id] = serialize_event( - event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_event_id, - ) - state_event_ids.append(event.event_id) - timeline_event_ids = [] - for event in timeline_events: - # TODO(mjark): Respect formatting requirements in the filter. - event_map[event.event_id] = serialize_event( - event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_event_id, - ) - timeline_event_ids.append(event.event_id) + serialized_state = [serialize(e) for e in state_events] + serialized_timeline = [serialize(e) for e in timeline_events] - private_user_data = filter.filter_room_private_user_data( - room.private_user_data + account_data = filter.filter_room_account_data( + room.account_data ) result = { - "event_map": event_map, "timeline": { - "events": timeline_event_ids, + "events": serialized_timeline, "prev_batch": room.timeline.prev_batch.to_string(), "limited": room.timeline.limited, }, - "state": {"events": state_event_ids}, - "private_user_data": {"events": private_user_data}, + "state": {"events": serialized_state}, + "account_data": {"events": account_data}, } if joined: diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 35482ae6a6..ba7223be11 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -81,7 +81,7 @@ class TagServlet(RestServlet): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) yield self.notifier.on_new_event( - "private_user_data_key", max_id, users=[user_id] + "account_data_key", max_id, users=[user_id] ) defer.returnValue((200, {})) @@ -95,7 +95,7 @@ class TagServlet(RestServlet): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) yield self.notifier.on_new_event( - "private_user_data_key", max_id, users=[user_id] + "account_data_key", max_id, users=[user_id] ) defer.returnValue((200, {})) diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index ab8b4d44ea..bfb7386035 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -17,12 +17,11 @@ var submitPassword = function(user, pwd) { }).error(errorFunc); }; -var submitCas = function(ticket, service) { - console.log("Logging in with cas..."); +var submitToken = function(loginToken) { + console.log("Logging in with login token..."); var data = { - type: "m.login.cas", - ticket: ticket, - service: service, + type: "m.login.token", + token: loginToken }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { show_login(); @@ -41,23 +40,10 @@ var errorFunc = function(err) { } }; -var getCasURL = function(cb) { - $.get(matrixLogin.endpoint + "/cas", function(response) { - var cas_url = response.serverUrl; - - cb(cas_url); - }).error(errorFunc); -}; - - var gotoCas = function() { - getCasURL(function(cas_url) { - var this_page = window.location.origin + window.location.pathname; - - var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page); - - window.location.replace(redirect_url); - }); + var this_page = window.location.origin + window.location.pathname; + var redirect_url = matrixLogin.endpoint + "/cas/redirect?redirectUrl=" + encodeURIComponent(this_page); + window.location.replace(redirect_url); } var setFeedbackString = function(text) { @@ -111,7 +97,7 @@ var fetch_info = function(cb) { matrixLogin.onLoad = function() { fetch_info(function() { - if (!try_cas()) { + if (!try_token()) { show_login(); } }); @@ -148,20 +134,20 @@ var parseQsFromUrl = function(query) { return result; }; -var try_cas = function() { +var try_token = function() { var pos = window.location.href.indexOf("?"); if (pos == -1) { return false; } var qs = parseQsFromUrl(window.location.href.substr(pos+1)); - var ticket = qs.ticket; + var loginToken = qs.loginToken; - if (!ticket) { + if (!loginToken) { return false; } - submitCas(ticket, location.origin); + submitToken(loginToken); return true; }; diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1a74d6e360..9800fd4203 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 25 +SCHEMA_VERSION = 26 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ae1ad56d9a..d32ce1ab1e 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -160,7 +160,7 @@ class RoomMemberStore(SQLBaseStore): def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, membership_list): - where_clause = "user_id = ? AND (%s)" % ( + where_clause = "user_id = ? AND (%s) AND NOT forgotten" % ( " OR ".join(["membership = ?" for _ in membership_list]), ) @@ -269,3 +269,67 @@ class RoomMemberStore(SQLBaseStore): ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 defer.returnValue(ret) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + self.runInteraction("forget_membership", f) + + @defer.inlineCallbacks + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + count = yield self.runInteraction("did_forget_membership", f) + defer.returnValue(count == 0) + + @defer.inlineCallbacks + def was_forgotten_at(self, user_id, room_id, event_id): + """Returns whether user_id has elected to discard history for room_id at event_id. + + event_id must be a membership event.""" + def f(txn): + sql = ( + "SELECT" + " forgotten" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " event_id = ?" + ) + txn.execute(sql, (user_id, room_id, event_id)) + rows = txn.fetchall() + return rows[0][0] + forgot = yield self.runInteraction("did_forget_membership_at", f) + defer.returnValue(forgot == 1) diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/schema/delta/26/account_data.sql new file mode 100644 index 0000000000..3198a0d29c --- /dev/null +++ b/synapse/storage/schema/delta/26/account_data.sql @@ -0,0 +1,17 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; diff --git a/synapse/storage/schema/delta/26/forgotten_memberships.sql b/synapse/storage/schema/delta/26/forgotten_memberships.sql new file mode 100644 index 0000000000..df55b9c6f6 --- /dev/null +++ b/synapse/storage/schema/delta/26/forgotten_memberships.sql @@ -0,0 +1,24 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Keeps track of what rooms users have left and don't want to be able to + * access again. + * + * If all users on this server have left a room, we can delete the room + * entirely. + */ + + ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER(1) DEFAULT 0; diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index bf695b7800..f6d826cc59 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -28,17 +28,17 @@ class TagsStore(SQLBaseStore): def __init__(self, hs): super(TagsStore, self).__init__(hs) - self._private_user_data_id_gen = StreamIdGenerator( - "private_user_data_max_stream_id", "stream_id" + self._account_data_id_gen = StreamIdGenerator( + "account_data_max_stream_id", "stream_id" ) - def get_max_private_user_data_stream_id(self): + def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream Returns: A deferred int. """ - return self._private_user_data_id_gen.get_max_token(self) + return self._account_data_id_gen.get_max_token(self) @cached() def get_tags_for_user(self, user_id): @@ -144,12 +144,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + with (yield self._account_data_id_gen.get_next(self)) as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._private_user_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token(self) defer.returnValue(result) @defer.inlineCallbacks @@ -166,12 +166,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + with (yield self._account_data_id_gen.get_next(self)) as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._private_user_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token(self) defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): @@ -185,7 +185,7 @@ class TagsStore(SQLBaseStore): """ update_max_id_sql = ( - "UPDATE private_user_data_max_stream_id" + "UPDATE account_data_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" ) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f0d68b5bf2..cfa7d30fa5 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -21,7 +21,7 @@ from synapse.handlers.presence import PresenceEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -from synapse.handlers.private_user_data import PrivateUserDataEventSource +from synapse.handlers.account_data import AccountDataEventSource class EventSources(object): @@ -30,7 +30,7 @@ class EventSources(object): "presence": PresenceEventSource, "typing": TypingNotificationEventSource, "receipt": ReceiptEventSource, - "private_user_data": PrivateUserDataEventSource, + "account_data": AccountDataEventSource, } def __init__(self, hs): @@ -54,8 +54,8 @@ class EventSources(object): receipt_key=( yield self.sources["receipt"].get_current_key() ), - private_user_data_key=( - yield self.sources["private_user_data"].get_current_key() + account_data_key=( + yield self.sources["account_data"].get_current_key() ), ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 28344d8b36..af1d76ab46 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -103,7 +103,7 @@ class StreamToken( "presence_key", "typing_key", "receipt_key", - "private_user_data_key", + "account_data_key", )) ): _SEPARATOR = "_" @@ -138,7 +138,7 @@ class StreamToken( or (int(other.presence_key) < int(self.presence_key)) or (int(other.typing_key) < int(self.typing_key)) or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.private_user_data_key) < int(self.private_user_data_key)) + or (int(other.account_data_key) < int(self.account_data_key)) ) def copy_and_advance(self, key, new_value): |