diff options
Diffstat (limited to 'synapse')
29 files changed, 765 insertions, 148 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 498ded38c0..da8ef90a77 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.18.7" +__version__ = "0.19.1" diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a8123cddcb..ca23c9c460 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -43,9 +43,6 @@ class JoinRules(object): class LoginType(object): PASSWORD = u"m.login.password" - OAUTH = u"m.login.oauth2" - EMAIL_CODE = u"m.login.email.code" - EMAIL_URL = u"m.login.email.url" EMAIL_IDENTITY = u"m.login.email.identity" RECAPTCHA = u"m.login.recaptcha" DUMMY = u"m.login.dummy" diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 3c58d2de17..e081840a83 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -95,7 +95,7 @@ class TlsConfig(Config): # make HTTPS requests to this server will check that the TLS # certificates returned by this server match one of the fingerprints. # - # Synapse automatically adds its the fingerprint of its own certificate + # Synapse automatically adds the fingerprint of its own certificate # to the list. So if federation traffic is handle directly by synapse # then no modification to the list is required. # diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index cb106c6a1b..bb3d9258a6 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -504,7 +504,7 @@ class TransactionQueue(object): code = e.code response = e.response - if e.code == 429 or 500 <= e.code: + if e.code in (401, 404, 429) or 500 <= e.code: logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 084e33ca6a..f36b358b45 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -19,7 +19,6 @@ from ._base import BaseHandler import logging - logger = logging.getLogger(__name__) @@ -54,3 +53,46 @@ class AdminHandler(BaseHandler): } defer.returnValue(ret) + + @defer.inlineCallbacks + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + ret = yield self.store.get_users() + + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + ret = yield self.store.get_users_paginate(order, start, limit) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + ret = yield self.store.search_users(term) + + defer.returnValue(ret) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 221d7ea7a2..fffba34383 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -65,6 +65,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -529,37 +530,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): - access_token = self.generate_access_token(user_id) + access_token = self.macaroon_gen.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) - def generate_access_token(self, user_id, extra_caveats=None): - extra_caveats = extra_caveats or [] - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = access") - # Include a nonce, to make sure that each login gets a different - # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - for caveat in extra_caveats: - macaroon.add_first_party_caveat(caveat) - return macaroon.serialize() - - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = login") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - return macaroon.serialize() - - def generate_delete_pusher_token(self, user_id): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = delete_pusher") - return macaroon.serialize() - def validate_short_term_login_token_and_get_user_id(self, login_token): auth_api = self.hs.get_auth() try: @@ -570,15 +545,6 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - def _generate_base_macaroon(self, user_id): - macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, - identifier="key", - key=self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - return macaroon - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) @@ -673,6 +639,48 @@ class AuthHandler(BaseHandler): return False +class MacaroonGeneartor(object): + def __init__(self, hs): + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + self.macaroon_secret_key = hs.config.macaroon_secret_key + + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = access") + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) + return macaroon.serialize() + + def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = login") + now = self.clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location=self.server_name, + identifier="key", + key=self.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + + class _AccountHandler(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7245d14fab..8cb47ac417 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,9 +14,11 @@ # limitations under the License. from synapse.api import errors +from synapse.api.constants import EventTypes from synapse.util import stringutils from synapse.util.async import Linearizer -from synapse.types import get_domain_from_id +from synapse.util.metrics import measure_func +from synapse.types import get_domain_from_id, RoomStreamToken from twisted.internet import defer from ._base import BaseHandler @@ -192,25 +194,28 @@ class DeviceHandler(BaseHandler): else: raise + @measure_func("notify_device_update") @defer.inlineCallbacks def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = [r.room_id for r in rooms] + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) hosts = set() if self.hs.is_mine_id(user_id): - for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) - hosts.update(get_domain_from_id(u) for u in users) + hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) position = yield self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, ) @@ -220,6 +225,61 @@ class DeviceHandler(BaseHandler): for host in hosts: self.federation_sender.send_device_messages(host) + @measure_func("device.get_user_ids_changed") + @defer.inlineCallbacks + def get_user_ids_changed(self, user_id, from_token): + """Get list of users that have had the devices updated, or have newly + joined a room, that `user_id` may be interested in. + + Args: + user_id (str) + from_token (StreamToken) + """ + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(r.room_id for r in rooms) + + # First we check if any devices have changed + changed = yield self.store.get_user_whose_devices_changed( + from_token.device_list_key + ) + + # Then work out if any users have since joined + rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + + possibly_changed = set(changed) + for room_id in rooms_changed: + # Fetch the current state at the time. + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key) + + try: + event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_ordering=stream_ordering + ) + prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + except: + prev_state_ids = {} + + current_state_ids = yield self.state.get_current_state_ids(room_id) + + # If there has been any change in membership, include them in the + # possibly changed list. We'll check if they are joined below, + # and we're not toooo worried about spuriously adding users. + for key, event_id in current_state_ids.iteritems(): + etype, state_key = key + if etype == EventTypes.Member: + prev_event_id = prev_state_ids.get(key, None) + if not prev_event_id or prev_event_id != event_id: + possibly_changed.add(state_key) + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + # Take the intersection of the users whose devices may have changed + # and those that actually still share a room with the user + defer.returnValue(users_who_share_room & possibly_changed) + + @measure_func("_incoming_device_list_update") @defer.inlineCallbacks def _incoming_device_list_update(self, origin, edu_content): user_id = edu_content["user_id"] diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9982ae0fed..fdfce2a88c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1011,7 +1011,7 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function def get_new_events(self, user, from_key, room_ids=None, include_offline=True, - **kwargs): + explicit_room_id=None, **kwargs): # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1028,22 +1028,24 @@ class PresenceEventSource(object): user_id = user.to_string() if from_key is not None: from_key = int(from_key) - room_ids = room_ids or [] presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - if not room_ids: - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = set(e.room_id for e in rooms) - else: - room_ids = set(room_ids) - max_token = self.store.get_current_presence_token() plist = yield self.store.get_presence_list_accepted(user.localpart) - friends = set(row["observed_user_id"] for row in plist) - friends.add(user_id) # So that we receive our own presence + users_interested_in = set(row["observed_user_id"] for row in plist) + users_interested_in.add(user_id) # So that we receive our own presence + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + users_interested_in.update(users_who_share_room) + + if explicit_room_id: + user_ids = yield self.store.get_users_in_room(explicit_room_id) + users_interested_in.update(user_ids) user_ids_changed = set() changed = None @@ -1055,35 +1057,19 @@ class PresenceEventSource(object): # work out if we share a room or they're in our presence list get_updates_counter.inc("stream") for other_user_id in changed: - if other_user_id in friends: + if other_user_id in users_interested_in: user_ids_changed.add(other_user_id) - continue - other_rooms = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(e.room_id for e in other_rooms): - user_ids_changed.add(other_user_id) - continue else: # Too many possible updates. Find all users we can see and check # if any of them have changed. get_updates_counter.inc("full") - user_ids_to_check = set() - for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) - user_ids_to_check.update(users) - - user_ids_to_check.update(friends) - - # Always include yourself. Only really matters for when the user is - # not in any rooms, but still. - user_ids_to_check.add(user_id) - if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - user_ids_to_check, from_key, + users_interested_in, from_key, ) else: - user_ids_changed = user_ids_to_check + user_ids_changed = users_interested_in updates = yield presence.current_state_for_users(user_ids_changed) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 286f0cef0a..03c6a85fc6 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id = None + self.macaroon_gen = hs.get_macaroon_generator() + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): @@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5f18007e90..7e7671c9a2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -437,6 +437,7 @@ class RoomEventSource(object): limit, room_ids, is_guest, + explicit_room_id=None, ): # We just ignore the key for now. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9199f20817..d7dcd1ce5b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,7 +16,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure +from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user from synapse.visibility import filter_events_for_client @@ -130,7 +130,8 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.invited or self.archived or self.account_data or - self.to_device + self.to_device or + self.device_lists ) @@ -560,6 +561,7 @@ class SyncHandler(object): next_batch=sync_result_builder.now_token, )) + @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks def _generate_sync_entry_for_device_list(self, sync_result_builder): user_id = sync_result_builder.sync_config.user.to_string() diff --git a/synapse/notifier.py b/synapse/notifier.py index acbd4bb5ae..8051a7a842 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -378,6 +378,7 @@ class Notifier(object): limit=limit, is_guest=is_peeking, room_ids=room_ids, + explicit_room_id=explicit_room_id, ) if name == "room": diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ce2d31fb98..62d794f22b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -81,7 +81,7 @@ class Mailer(object): def __init__(self, hs, app_name): self.hs = hs self.store = self.hs.get_datastore() - self.auth_handler = self.hs.get_auth_handler() + self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name @@ -466,7 +466,7 @@ class Mailer(object): def make_unsubscribe_link(self, user_id, app_id, email_address): params = { - "access_token": self.auth_handler.generate_delete_pusher_token(user_id), + "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, "pushkey": email_address, } diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index a30e647474..d8eb14592b 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -299,9 +299,6 @@ class ReplicationResource(Resource): "backward_ex_outliers", res.backward_ex_outliers, ("position", "event_id", "state_group"), ) - writer.write_header_and_rows( - "state_resets", res.state_resets, ("position",), - ) @defer.inlineCallbacks def presence(self, writer, current_token, request_streams): diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b3f3bf7488..d72ff6055c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -73,6 +73,9 @@ class SlavedEventStore(BaseSlavedStore): # to reach inside the __dict__ to extract them. get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_users_who_share_room_with_user = ( + RoomMemberStore.__dict__["get_users_who_share_room_with_user"] + ) get_latest_event_ids_in_room = EventFederationStore.__dict__[ "get_latest_event_ids_in_room" ] @@ -192,10 +195,6 @@ class SlavedEventStore(BaseSlavedStore): return result def process_replication(self, result): - state_resets = set( - r[0] for r in result.get("state_resets", {"rows": []})["rows"] - ) - stream = result.get("events") if stream: self._stream_id_gen.advance(int(stream["position"])) @@ -205,7 +204,7 @@ class SlavedEventStore(BaseSlavedStore): for row in stream["rows"]: self._process_replication_row( - row, backfilled=False, state_resets=state_resets + row, backfilled=False, ) stream = result.get("backfill") @@ -213,7 +212,7 @@ class SlavedEventStore(BaseSlavedStore): self._backfill_id_gen.advance(-int(stream["position"])) for row in stream["rows"]: self._process_replication_row( - row, backfilled=True, state_resets=state_resets + row, backfilled=True, ) stream = result.get("forward_ex_outliers") @@ -232,20 +231,15 @@ class SlavedEventStore(BaseSlavedStore): return super(SlavedEventStore, self).process_replication(result) - def _process_replication_row(self, row, backfilled, state_resets): - position = row[0] + def _process_replication_row(self, row, backfilled): internal = json.loads(row[1]) event_json = json.loads(row[2]) event = FrozenEvent(event_json, internal_metadata_dict=internal) self.invalidate_caches_for_event( - event, backfilled, reset_state=position in state_resets + event, backfilled, ) - def invalidate_caches_for_event(self, event, backfilled, reset_state): - if reset_state: - self.get_rooms_for_user.invalidate_all() - self.get_users_in_room.invalidate((event.room_id,)) - + def invalidate_caches_for_event(self, event, backfilled): self._invalidate_get_event_cache(event.event_id) self.get_latest_event_ids_in_room.invalidate((event.room_id,)) @@ -267,8 +261,6 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.redacts) if event.type == EventTypes.Member: - self.get_rooms_for_user.invalidate((event.state_key,)) - self.get_users_in_room.invalidate((event.room_id,)) self._membership_stream_cache.entity_has_changed( event.state_key, event.internal_metadata.stream_ordering ) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index af21661d7c..29fcd72375 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID +from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -25,6 +26,34 @@ import logging logger = logging.getLogger(__name__) +class UsersRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/users/(?P<user_id>[^/]*)") + + def __init__(self, hs): + super(UsersRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + target_user = UserID.from_string(user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + ret = yield self.handlers.admin_handler.get_users() + + defer.returnValue((200, ret)) + + class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") @@ -128,8 +157,199 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) +class ResetPasswordRestServlet(ClientV1RestServlet): + """Post request to allow an administrator reset password for a user. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/reset_password/ + @user:to_reset_password?access_token=admin_access_token + JsonBodyToSend: + { + "new_password": "secret" + } + Returns: + 200 OK with empty object if success otherwise an error. + """ + PATTERNS = client_path_patterns("/admin/reset_password/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(ResetPasswordRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + """Post request to allow an administrator reset password for a user. + This need a user have a administrator access in Synapse. + """ + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + params = parse_json_object_from_request(request) + new_password = params['new_password'] + if not new_password: + raise SynapseError(400, "Missing 'new_password' arg") + + logger.info("new_password: %r", new_password) + + yield self.auth_handler.set_password( + target_user_id, new_password, requester + ) + defer.returnValue((200, {})) + + +class GetUsersPaginatedRestServlet(ClientV1RestServlet): + """Get request to get specific number of users from Synapse. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + @admin:user?access_token=admin_access_token&start=0&limit=10 + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + PATTERNS = client_path_patterns("/admin/users_paginate/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(GetUsersPaginatedRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, target_user_id): + """Get request to get specific number of users from Synapse. + This need a user have a administrator access in Synapse. + """ + target_user = UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + order = "name" # order by name in user table + start = request.args.get("start")[0] + limit = request.args.get("limit")[0] + if not limit: + raise SynapseError(400, "Missing 'limit' arg") + if not start: + raise SynapseError(400, "Missing 'start' arg") + logger.info("limit: %s, start: %s", limit, start) + + ret = yield self.handlers.admin_handler.get_users_paginate( + order, start, limit + ) + defer.returnValue((200, ret)) + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + """Post request to get specific number of users from Synapse.. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + @admin:user?access_token=admin_access_token + JsonBodyToSend: + { + "start": "0", + "limit": "10 + } + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + order = "name" # order by name in user table + params = parse_json_object_from_request(request) + limit = params['limit'] + start = params['start'] + if not limit: + raise SynapseError(400, "Missing 'limit' arg") + if not start: + raise SynapseError(400, "Missing 'start' arg") + logger.info("limit: %s, start: %s", limit, start) + + ret = yield self.handlers.admin_handler.get_users_paginate( + order, start, limit + ) + defer.returnValue((200, ret)) + + +class SearchUsersRestServlet(ClientV1RestServlet): + """Get request to search user table for specific users according to + search term. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/search_users/ + @admin:user?access_token=admin_access_token&term=alice + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + PATTERNS = client_path_patterns("/admin/search_users/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(SearchUsersRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, target_user_id): + """Get request to search user table for specific users according to + search term. + This need a user have a administrator access in Synapse. + """ + target_user = UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + term = request.args.get("term")[0] + if not term: + raise SynapseError(400, "Missing 'term' arg") + + logger.info("term: %s ", term) + + ret = yield self.handlers.admin_handler.search_users( + term + ) + defer.returnValue((200, ret)) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) + UsersRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) + GetUsersPaginatedRestServlet(hs).register(http_server) + SearchUsersRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0c9cdff3b8..72057f1b0c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_GET(self, request): @@ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet): yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(registered_user_id) + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id + ) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 355e82474b..1a5045c9ec 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -46,6 +46,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): def on_PUT(self, request, user_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) + is_admin = yield self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -55,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_displayname( - user, requester, new_name) + user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -88,6 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): def on_PUT(self, request, user_id): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + is_admin = yield self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: @@ -96,7 +98,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_avatar_url( - user, requester, new_name) + user, requester, new_name, is_admin) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 46789775b9..6a3cfe84f8 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -21,6 +21,8 @@ from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_integer ) +from synapse.http.servlet import parse_string +from synapse.types import StreamToken from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -149,6 +151,52 @@ class KeyQueryServlet(RestServlet): defer.returnValue((200, result)) +class KeyChangesServlet(RestServlet): + """Returns the list of changes of keys between two stream tokens (may return + spurious extra results, since we currently ignore the `to` param). + + GET /keys/changes?from=...&to=... + + 200 OK + { "changed": ["@foo:example.com"] } + """ + PATTERNS = client_v2_patterns( + "/keys/changes$", + releases=() + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + super(KeyChangesServlet, self).__init__() + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + from_token_string = parse_string(request, "from") + + # We want to enforce they do pass us one, but we ignore it and return + # changes after the "to" as well as before. + parse_string(request, "to") + + from_token = StreamToken.from_string(from_token_string) + + user_id = requester.user.to_string() + + changed = yield self.device_handler.get_user_ids_changed( + user_id, from_token, + ) + + defer.returnValue((200, { + "changed": list(changed), + })) + + class OneTimeKeyServlet(RestServlet): """ POST /keys/claim HTTP/1.1 @@ -192,4 +240,5 @@ class OneTimeKeyServlet(RestServlet): def register_servlets(hs, http_server): KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) + KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3e7a285e10..ccca5a12d5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet): self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_POST(self, request): @@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name ) - access_token = self.auth_handler.generate_access_token( + access_token = self.macaroon_gen.generate_access_token( user_id, ["guest = true"] ) defer.returnValue((200, { diff --git a/synapse/server.py b/synapse/server.py index 0bfb411269..c577032041 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler +from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler @@ -131,6 +131,7 @@ class HomeServer(object): 'federation_transport_client', 'federation_sender', 'receipts_handler', + 'macaroon_generator', ] def __init__(self, hostname, **kwargs): @@ -213,6 +214,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_macaroon_generator(self): + return MacaroonGeneartor(self) + def build_device_handler(self): return DeviceHandler(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b9968debe5..d604e7668f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -297,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore, desc="get_user_ip_and_agents", ) + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_select_list( + table="users", + keyvalues={}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="get_users", + ) + + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + is_guest = 0 + i_start = (int)(start) + i_limit = (int)(limit) + return self.get_user_list_paginate( + table="users", + keyvalues={ + "is_guest": is_guest + }, + pagevalues=[ + order, + i_limit, + i_start + ], + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="get_users_paginate", + ) + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_search_list( + table="users", + term=term, + col="name", + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="search_users", + ) + def are_all_users_on_domain(txn, database_engine, domain): sql = database_engine.convert_param_style( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 05374682fd..b0dc391190 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -934,6 +934,165 @@ class SQLBaseStore(object): else: return 0 + def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols, + desc="_simple_select_list_paginate"): + """Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self._simple_select_list_paginate_txn, + table, keyvalues, pagevalues, retcols + ) + + @classmethod + def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols): + """Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + pagevalues ([]): + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + " ? ASC LIMIT ? OFFSET ?" + ) + txn.execute(sql, keyvalues.values() + pagevalues) + else: + sql = "SELECT %s FROM %s ORDER BY %s" % ( + ", ".join(retcols), + table, + " ? ASC LIMIT ? OFFSET ?" + ) + txn.execute(sql, pagevalues) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols, + desc="get_user_list_paginate"): + """Get a list of users from start row to a limit number of rows. This will + return a json object with users and total number of users in users list. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + pagevalues ([]): + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + users = yield self.runInteraction( + desc, + self._simple_select_list_paginate_txn, + table, keyvalues, pagevalues, retcols + ) + count = yield self.runInteraction( + desc, + self.get_user_count_txn + ) + retval = { + "users": users, + "total": count + } + defer.returnValue(retval) + + def get_user_count_txn(self, txn): + """Get a total number of registerd users in the users list. + + Args: + txn : Transaction object + Returns: + defer.Deferred: resolves to int + """ + sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" + txn.execute(sql_count) + count = txn.fetchone()[0] + defer.returnValue(count) + + def _simple_search_list(self, table, term, col, retcols, + desc="_simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, + self._simple_search_list_txn, + table, term, col, retcols + ) + + @classmethod + def _simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % ( + ", ".join(retcols), + table, + col + ) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2040e022fa..b9f1365f92 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -93,7 +93,7 @@ class EndToEndKeyStore(SQLBaseStore): query_clause = "user_id = ?" query_params.append(user_id) - if device_id: + if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6685b9da1c..c88f689d3a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -28,6 +28,7 @@ from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.state import resolve_events +from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -301,7 +302,7 @@ class EventsStore(SQLBaseStore): room_id ) new_latest_event_ids = yield self._calculate_new_extremeties( - room_id, [ev for ev, _ in ev_ctx_rm] + room_id, ev_ctx_rm, latest_event_ids ) if new_latest_event_ids == set(latest_event_ids): @@ -328,27 +329,24 @@ class EventsStore(SQLBaseStore): persist_event_counter.inc_by(len(chunk)) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, events): + def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): """Calculates the new forward extremeties for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) new_latest_event_ids = set(latest_event_ids) # First, add all the new events to the list new_latest_event_ids.update( - event.event_id for event in events - if not event.internal_metadata.is_outlier() + event.event_id for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() and not ctx.rejected ) # Now remove all events that are referenced by the to-be-added events new_latest_event_ids.difference_update( e_id - for event in events + for event, ctx in event_contexts for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() + if not event.internal_metadata.is_outlier() and not ctx.rejected ) # And finally remove any events that are referenced by previously added @@ -572,14 +570,6 @@ class EventsStore(SQLBaseStore): txn, self.get_users_in_room, (room_id,) ) - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": max_stream_order} - ) - for room_id, new_extrem in new_forward_extremeties.items(): self._simple_delete_txn( txn, @@ -1579,6 +1569,7 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): """Get all the new events that have arrived at the server either as @@ -1611,15 +1602,6 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT event_stream_ordering FROM current_state_resets" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering ASC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - state_resets = txn.fetchall() - - sql = ( "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" @@ -1630,7 +1612,6 @@ class EventsStore(SQLBaseStore): forward_ex_outliers = txn.fetchall() else: new_forward_events = [] - state_resets = [] forward_ex_outliers = [] sql = ( @@ -1670,7 +1651,6 @@ class EventsStore(SQLBaseStore): return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, - state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) @@ -1896,5 +1876,4 @@ class EventsStore(SQLBaseStore): AllNewEventsResult = namedtuple("AllNewEventsResult", [ "new_forward_events", "new_backfill_events", "forward_ex_outliers", "backward_ex_outliers", - "state_resets" ]) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 10f7c7a4bc..545d3d3a99 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering @@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=100000, iterable=True) + @cached(max_entries=500000, iterable=True) def get_users_in_room(self, room_id): def f(txn): @@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore): " ON m.event_id = c.event_id " " AND m.room_id = c.room_id " " AND m.user_id = c.state_key" - " WHERE %(where)s" + " WHERE c.type = 'm.room.member' AND %(where)s" ) % { "where": where_clause, } @@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore): return rows - @cached(max_entries=5000) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user(self, user_id): return self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + rooms = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate, + ) + + user_who_share_room = set() + for room in rooms: + user_ids = yield self.get_users_in_room( + room.room_id, on_invalidate=cache_context.invalidate, + ) + user_who_share_room.update(user_ids) + + defer.returnValue(user_who_share_room) + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2dc24951c4..200d124632 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) + def get_rooms_that_changed(self, room_ids, from_key): + """Given a list of rooms and a token, return rooms where there may have + been changes. + + Args: + room_ids (list) + from_key (str): The room_key portion of a StreamToken + """ + from_key = RoomStreamToken.parse_stream_token(from_key).stream + return set( + room_id for room_id in room_ids + if self._events_stream_cache.has_entity_changed(room_id, from_key) + ) + @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 675bfd5feb..998de70d29 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -478,6 +478,11 @@ class CacheListDescriptor(object): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): + # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, + # which namedtuple does for us (i.e. two _CacheContext are the same if + # their caches and keys match). This is important in particular to + # dedupe when we add callbacks to lru cache nodes, otherwise the number + # of callbacks would grow. def invalidate(self): self.cache.invalidate(self.key) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index b94ae369cf..153ef001ad 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -129,11 +129,13 @@ class RetryDestinationLimiter(object): # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS # has been decommissioned. + # If we get a 401, then we should probably back off since they + # won't accept our requests for at least a while. + # 429 is us being aggresively rate limited, so lets rate limit + # ourselves. if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False - elif exc_val.code == 429: - # 429 is us being aggresively rate limited, so lets rate limit - # ourselves. + elif exc_val.code in (401, 429): valid_err_code = False elif exc_val.code < 500: valid_err_code = True |