diff options
Diffstat (limited to 'synapse')
64 files changed, 2946 insertions, 1951 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e2f84c4d57..183245443c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -434,31 +434,46 @@ class Auth(object): if event.user_id != invite_event.user_id: return False - try: - public_key = invite_event.content["public_key"] - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - return False - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True + if signed["mxid"] != event.state_key: return False - except (KeyError, SignatureVerifyException,): + if signed["token"] != token: return False + for public_key_object in self.get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + def get_public_keys(self, invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys + def _get_power_level_event(self, auth_events): key = (EventTypes.PowerLevels, "", ) return auth_events.get(key) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 84cbe710b3..8cf4d6169c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -32,7 +32,6 @@ class PresenceState(object): OFFLINE = u"offline" UNAVAILABLE = u"unavailable" ONLINE = u"online" - FREE_FOR_CHAT = u"free_for_chat" class JoinRules(object): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 6eff83e5f8..cd699ef27f 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -198,7 +198,10 @@ class Filter(object): sender = event.get("sender", None) if not sender: # Presence events have their 'sender' in content.user_id - sender = event.get("content", {}).get("user_id", None) + content = event.get("content") + # account_data has been allowed to have non-dict content, so check type first + if isinstance(content, dict): + sender = content.get("user_id") return self.check_fields( event.get("room_id", None), diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b4be7bdd0..de5ee988f1 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer from synapse import events @@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer): if name == "metrics" and self.get_config().enable_metrics: resources[METRICS_PREFIX] = MetricsResource(self) + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationResource(self) + root_resource = create_resource_tree(resources) if tls: reactor.listenSSL( diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 90718192dd..e8bfbe7cb5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -543,8 +543,19 @@ class FederationServer(FederationBase): return event @defer.inlineCallbacks - def exchange_third_party_invite(self, invite): - ret = yield self.handler.exchange_third_party_invite(invite) + def exchange_third_party_invite( + self, + sender_user_id, + target_user_id, + room_id, + signed, + ): + ret = yield self.handler.exchange_third_party_invite( + sender_user_id, + target_user_id, + room_id, + signed, + ) defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 65e054f7dd..6e92e2f8f4 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -425,7 +425,17 @@ class On3pidBindServlet(BaseFederationServlet): last_exception = None for invite in content["invites"]: try: - yield self.handler.exchange_third_party_invite(invite) + if "signed" not in invite or "token" not in invite["signed"]: + message = ("Rejecting received notification of third-" + "party invite without signed: %s" % (invite,)) + logger.info(message) + raise SynapseError(400, message) + yield self.handler.exchange_third_party_invite( + invite["sender"], + invite["mxid"], + invite["room_id"], + invite["signed"], + ) except Exception as e: last_exception = e if last_exception: diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 064e8723c8..bdade98bf7 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, RoomAlias +from synapse.types import UserID, RoomAlias, Requester from synapse.push.action_generator import ActionGenerator from synapse.util.logcontext import PreserveLoggingContext @@ -53,9 +53,15 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_clients(self, user_tuples, events, event_id_to_state): + def filter_events_for_clients(self, user_tuples, events, event_id_to_state): """ Returns dict of user_id -> list of events that user is allowed to see. + + :param (str, bool) user_tuples: (user id, is_peeking) for each + user to be checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the given + events """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -72,18 +78,20 @@ class BaseHandler(object): def allowed(event, user_id, is_peeking): state = event_id_to_state[event.event_id] + # get the room_visibility at the time of the event. visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) if visibility_event: visibility = visibility_event.content.get("history_visibility", "shared") else: visibility = "shared" + # if it was world_readable, it's easy: everyone can read it if visibility == "world_readable": return True - if is_peeking: - return False - + # get the user's membership at the time of the event. (or rather, + # just *after* the event. Which means that people can see their + # own join events, but not (currently) their own leave events.) membership_event = state.get((EventTypes.Member, user_id), None) if membership_event: if membership_event.event_id in event_id_forgotten: @@ -93,20 +101,29 @@ class BaseHandler(object): else: membership = None + # if the user was a member of the room at the time of the event, + # they can see it. if membership == Membership.JOIN: return True - if event.type == EventTypes.RoomHistoryVisibility: - return not is_peeking + if visibility == "joined": + # we weren't a member at the time of the event, so we can't + # see this event. + return False - if visibility == "shared": - return True - elif visibility == "joined": - return membership == Membership.JOIN elif visibility == "invited": + # user can also see the event if they were *invited* at the time + # of the event. return membership == Membership.INVITE - return True + else: + # visibility is shared: user can also see the event if they have + # become a member since the event + # + # XXX: if the user has subsequently joined and then left again, + # ideally we would share history up to the point they left. But + # we don't know when they left. + return not is_peeking defer.returnValue({ user_id: [ @@ -119,7 +136,17 @@ class BaseHandler(object): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, events, is_peeking=False): - # Assumes that user has at some point joined the room if not is_guest. + """ + Check which events a user is allowed to see + + :param str user_id: user id to be checked + :param [synapse.events.EventBase] events: list of events to be checked + :param bool is_peeking should be True if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the given + events + :rtype [synapse.events.EventBase] + """ types = ( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id), @@ -128,7 +155,7 @@ class BaseHandler(object): frozenset(e.event_id for e in events), types=types ) - res = yield self._filter_events_for_clients( + res = yield self.filter_events_for_clients( [(user_id, is_peeking)], events, event_id_to_state ) defer.returnValue(res.get(user_id, [])) @@ -147,7 +174,7 @@ class BaseHandler(object): @defer.inlineCallbacks def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_events_in_room( + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( builder.room_id, ) @@ -156,7 +183,10 @@ class BaseHandler(object): else: depth = 1 - prev_events = [(e, h) for e, h, _ in latest_ret] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] builder.prev_events = prev_events builder.depth = depth @@ -165,6 +195,31 @@ class BaseHandler(object): context = yield state_handler.compute_event_context(builder) + # If we've received an invite over federation, there are no latest + # events in the room, because we don't know enough about the graph + # fragment we received to treat it like a graph, so the above returned + # no relevant events. It may have returned some events (if we have + # joined and left the room), but not useful ones, like the invite. So we + # forcibly set our context to the invite we received over federation. + if ( + not self.is_host_in_room(context.current_state) and + builder.type == EventTypes.Member + ): + prev_member_event = yield self.store.get_room_member( + builder.sender, builder.room_id + ) + if prev_member_event: + builder.prev_events = ( + prev_member_event.event_id, + prev_member_event.prev_events + ) + + context = yield state_handler.compute_event_context( + builder, + old_state=(prev_member_event,), + outlier=True + ) + if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( context.prev_state_events @@ -187,10 +242,33 @@ class BaseHandler(object): (event, context,) ) + def is_host_in_room(self, current_state): + room_members = [ + (state_key, event.membership) + for ((event_type, state_key), event) in current_state.items() + if event_type == EventTypes.Member + ] + if len(room_members) == 0: + # Have we just created the room, and is this about to be the very + # first member event? + create_event = current_state.get(("m.room.create", "")) + if create_event: + return True + for (state_key, membership) in room_members: + if ( + UserID.from_string(state_key).domain == self.hs.hostname + and membership == Membership.JOIN + ): + return True + return False + @defer.inlineCallbacks - def handle_new_client_event(self, event, context, extra_users=[]): + def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): # We now need to go and hit out to wherever we need to hit out to. + if ratelimit: + self.ratelimit(event.sender) + self.auth.check(event, auth_events=context.current_state) yield self.maybe_kick_guest_users(event, context.current_state.values()) @@ -215,6 +293,12 @@ class BaseHandler(object): if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): + return ( + e.type == EventTypes.Member and + e.sender == event.sender + ) + event.unsigned["invite_room_state"] = [ { "type": e.type, @@ -228,7 +312,7 @@ class BaseHandler(object): EventTypes.CanonicalAlias, EventTypes.RoomAvatar, EventTypes.Name, - ) + ) or is_inviter_member_event(e) ] invitee = UserID.from_string(event.state_key) @@ -316,7 +400,8 @@ class BaseHandler(object): if member_event.type != EventTypes.Member: continue - if not self.hs.is_mine(UserID.from_string(member_event.state_key)): + target_user = UserID.from_string(member_event.state_key) + if not self.hs.is_mine(target_user): continue if member_event.content["membership"] not in { @@ -338,18 +423,13 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - message_handler = self.hs.get_handlers().message_handler - yield message_handler.create_and_send_event( - { - "type": EventTypes.Member, - "state_key": member_event.state_key, - "content": { - "membership": Membership.LEAVE, - "kind": "guest" - }, - "room_id": member_event.room_id, - "sender": member_event.state_key - }, + requester = Requester(target_user, "", True) + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + target_user, + member_event.room_id, + "leave", ratelimit=False, ) except Exception as e: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4efecb1ffd..e0a778e7ff 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -216,7 +216,7 @@ class DirectoryHandler(BaseHandler): aliases = yield self.store.get_aliases_for_room(room_id) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Aliases, "state_key": self.hs.hostname, "room_id": room_id, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4933c31c19..72a31a9755 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,8 @@ from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event from synapse.util.logcontext import preserve_context_over_fn +from synapse.api.constants import Membership, EventTypes +from synapse.events import EventBase from ._base import BaseHandler @@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) + presence_handler = self.hs.get_handlers().presence_handler - try: - if affect_presence: - yield self.started_stream(auth_user) - + context = yield presence_handler.user_syncing( + auth_user_id, affect_presence=affect_presence, + ) + with context: if timeout: # If they've set a timeout set a minimum limit. timeout = max(timeout, 500) @@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler): is_guest=is_guest, explicit_room_id=room_id ) + # When the user joins a new room, or another user joins a currently + # joined room, we need to send down presence for those users. + to_add = [] + for event in events: + if not isinstance(event, EventBase): + continue + if event.type == EventTypes.Member: + if event.membership != Membership.JOIN: + continue + # Send down presence. + if event.state_key == auth_user_id: + # Send down presence for everyone in the room. + users = yield self.store.get_users_in_room(event.room_id) + states = yield presence_handler.get_states( + users, + as_event=True, + ) + to_add.extend(states) + else: + + ev = yield presence_handler.get_state( + UserID.from_string(event.state_key), + as_event=True, + ) + to_add.append(ev) + + events.extend(to_add) + time_now = self.clock.time_msec() chunks = [ @@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) - finally: - if affect_presence: - self.stopped_stream(auth_user) - class EventHandler(BaseHandler): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index da55d43541..3655b9e5e2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -14,6 +14,9 @@ # limitations under the License. """Contains handlers for federation events.""" +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 from ._base import BaseHandler @@ -1620,19 +1623,15 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, invite): - sender = invite["sender"] - room_id = invite["room_id"] - - if "signed" not in invite or "token" not in invite["signed"]: - logger.info( - "Discarding received notification of third party invite " - "without signed: %s" % (invite,) - ) - return - + def exchange_third_party_invite( + self, + sender_user_id, + target_user_id, + room_id, + signed, + ): third_party_invite = { - "signed": invite["signed"], + "signed": signed, } event_dict = { @@ -1642,8 +1641,8 @@ class FederationHandler(BaseHandler): "third_party_invite": third_party_invite, }, "room_id": room_id, - "sender": sender, - "state_key": invite["mxid"], + "sender": sender_user_id, + "state_key": target_user_id, } if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): @@ -1656,11 +1655,11 @@ class FederationHandler(BaseHandler): ) self.auth.check(event, context.current_state) - yield self._validate_keyserver(event, auth_events=context.current_state) + yield self._check_signature(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context) + yield member_handler.send_membership_event(event, context, from_client=False) else: - destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) + destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.replication_layer.forward_third_party_invite( destinations, room_id, @@ -1681,13 +1680,13 @@ class FederationHandler(BaseHandler): ) self.auth.check(event, auth_events=context.current_state) - yield self._validate_keyserver(event, auth_events=context.current_state) + yield self._check_signature(event, auth_events=context.current_state) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context) + yield member_handler.send_membership_event(event, context, from_client=False) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): @@ -1711,17 +1710,69 @@ class FederationHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def _validate_keyserver(self, event, auth_events): - token = event.content["third_party_invite"]["signed"]["token"] + def _check_signature(self, event, auth_events): + """ + Checks that the signature in the event is consistent with its invite. + :param event (Event): The m.room.member event to check + :param auth_events (dict<(event type, state_key), event>) + + :raises + AuthError if signature didn't match any keys, or key has been + revoked, + SynapseError if a transient error meant a key couldn't be checked + for revocation. + """ + signed = event.content["third_party_invite"]["signed"] + token = signed["token"] invite_event = auth_events.get( (EventTypes.ThirdPartyInvite, token,) ) + if not invite_event: + raise AuthError(403, "Could not find invite") + + last_exception = None + for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + + public_key = public_key_object["public_key"] + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + if "key_validity_url" in public_key_object: + yield self._check_key_revocation( + public_key, + public_key_object["key_validity_url"] + ) + return + except Exception as e: + last_exception = e + raise last_exception + + @defer.inlineCallbacks + def _check_key_revocation(self, public_key, url): + """ + Checks whether public_key has been revoked. + + :param public_key (str): base-64 encoded public key. + :param url (str): Key revocation URL. + + :raises + AuthError if they key has been revoked. + SynapseError if a transient error meant a key couldn't be checked + for revocation. + """ try: response = yield self.hs.get_simple_http_client().get_json( - invite_event.content["key_validity_url"], - {"public_key": invite_event.content["public_key"]} + url, + {"public_key": public_key} ) except Exception: raise SynapseError( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 82c8cb5f0c..afa7c9c36c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,12 +16,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -216,7 +215,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_event(self, event, context, ratelimit=True, is_guest=False): + def send_nonmember_event(self, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -226,55 +225,68 @@ class MessageHandler(BaseHandler): ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. """ + if event.type == EventTypes.Member: + raise SynapseError( + 500, + "Tried to send member event through non-member codepath" + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) - if ratelimit: - self.ratelimit(event.sender) - if event.is_state(): - prev_state = context.current_state.get((event.type, event.state_key)) - if prev_state and event.user_id == prev_state.user_id: - prev_content = encode_canonical_json(prev_state.content) - next_content = encode_canonical_json(event.content) - if prev_content == next_content: - # Duplicate suppression for state updates with same sender - # and content. - defer.returnValue(prev_state) + prev_state = self.deduplicate_state_event(event, context) + if prev_state is not None: + defer.returnValue(prev_state) - if event.type == EventTypes.Member: - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, is_guest=is_guest) - else: - yield self.handle_new_client_event( - event=event, - context=context, - ) + yield self.handle_new_client_event( + event=event, + context=context, + ratelimit=ratelimit, + ) if event.type == EventTypes.Message: presence = self.hs.get_handlers().presence_handler - with PreserveLoggingContext(): - presence.bump_presence_active_time(user) + yield presence.bump_presence_active_time(user) + + def deduplicate_state_event(self, event, context): + """ + Checks whether event is in the latest resolved state in context. + + If so, returns the version of the event in context. + Otherwise, returns None. + """ + prev_event = context.current_state.get((event.type, event.state_key)) + if prev_event and event.user_id == prev_event.user_id: + prev_content = encode_canonical_json(prev_event.content) + next_content = encode_canonical_json(event.content) + if prev_content == next_content: + return prev_event + return None @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True, - token_id=None, txn_id=None, is_guest=False): + def create_and_send_nonmember_event( + self, + event_dict, + ratelimit=True, + token_id=None, + txn_id=None + ): """ Creates an event, then sends it. - See self.create_event and self.send_event. + See self.create_event and self.send_nonmember_event. """ event, context = yield self.create_event( event_dict, token_id=token_id, txn_id=txn_id ) - yield self.send_event( + yield self.send_nonmember_event( event, context, ratelimit=ratelimit, - is_guest=is_guest ) defer.returnValue(event) @@ -660,10 +672,6 @@ class MessageHandler(BaseHandler): room_id=room_id, ) - # TODO(paul): I wish I was called with user objects not user_id - # strings... - auth_user = UserID.from_string(user_id) - # TODO: These concurrently time_now = self.clock.time_msec() state = [ @@ -688,13 +696,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): states = yield presence_handler.get_states( - target_users=[UserID.from_string(m.user_id) for m in room_members], - auth_user=auth_user, + [m.user_id for m in room_members], as_event=True, - check_auth=False, ) - defer.returnValue(states.values()) + defer.returnValue(states) @defer.inlineCallbacks def get_receipts(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b61394f2b5..f6cf343174 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -13,13 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +"""This module is responsible for keeping track of presence status of local +and remote users. -from synapse.api.errors import SynapseError, AuthError +The methods that define policy are: + - PresenceHandler._update_states + - PresenceHandler._handle_timeouts + - should_notify +""" + +from twisted.internet import defer, reactor +from contextlib import contextmanager + +from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState +from synapse.storage.presence import UserPresenceState -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function +from synapse.util.metrics import Measure +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID import synapse.metrics @@ -32,34 +45,32 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +notified_presence_counter = metrics.register_counter("notified_presence") +federation_presence_out_counter = metrics.register_counter("federation_presence_out") +presence_updates_counter = metrics.register_counter("presence_updates") +timers_fired_counter = metrics.register_counter("timers_fired") +federation_presence_counter = metrics.register_counter("federation_presence") +bump_active_time_counter = metrics.register_counter("bump_active_time") -# Don't bother bumping "last active" time if it differs by less than 60 seconds -LAST_ACTIVE_GRANULARITY = 60 * 1000 - -# Keep no more than this number of offline serial revisions -MAX_OFFLINE_SERIALS = 1000 +# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them +# "currently_active" +LAST_ACTIVE_GRANULARITY = 60 * 1000 -# TODO(paul): Maybe there's one of these I can steal from somewhere -def partition(l, func): - """Partition the list by the result of func applied to each element.""" - ret = {} - - for x in l: - key = func(x) - if key not in ret: - ret[key] = [] - ret[key].append(x) +# How long to wait until a new /events or /sync request before assuming +# the client has gone. +SYNC_ONLINE_TIMEOUT = 30 * 1000 - return ret +# How long to wait before marking the user as idle. Compared against last active +IDLE_TIMER = 5 * 60 * 1000 +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 30 * 60 * 1000 -def partitionbool(l, func): - def boolfunc(x): - return bool(func(x)) +# How often to resend presence to remote servers +FEDERATION_PING_INTERVAL = 25 * 60 * 1000 - ret = partition(l, boolfunc) - return ret.get(True, []), ret.get(False, []) +assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER def user_presence_changed(distributor, user, statuscache): @@ -72,45 +83,13 @@ def collect_presencelike_data(distributor, user, content): class PresenceHandler(BaseHandler): - STATE_LEVELS = { - PresenceState.OFFLINE: 0, - PresenceState.UNAVAILABLE: 1, - PresenceState.ONLINE: 2, - PresenceState.FREE_FOR_CHAT: 3, - } - def __init__(self, hs): super(PresenceHandler, self).__init__(hs) - - self.homeserver = hs - + self.hs = hs self.clock = hs.get_clock() - - distributor = hs.get_distributor() - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "started_user_eventstream", self.started_user_eventstream - ) - distributor.observe( - "stopped_user_eventstream", self.stopped_user_eventstream - ) - - distributor.observe("user_joined_room", self.user_joined_room) - - distributor.declare("collect_presencelike_data") - - distributor.declare("changed_presencelike_data") - distributor.observe( - "changed_presencelike_data", self.changed_presencelike_data - ) - - # outbound signal from the presence module to advertise when a user's - # presence has changed - distributor.declare("user_presence_changed") - - self.distributor = distributor - + self.store = hs.get_datastore() + self.wheel_timer = WheelTimer() + self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() self.federation.register_edu_handler( @@ -138,348 +117,552 @@ class PresenceHandler(BaseHandler): ) ) - # IN-MEMORY store, mapping local userparts to sets of local users to - # be informed of state changes. - self._local_pushmap = {} - # map local users to sets of remote /domain names/ who are interested - # in them - self._remote_sendmap = {} - # map remote users to sets of local users who're interested in them - self._remote_recvmap = {} - # list of (serial, set of(userids)) tuples, ordered by serial, latest - # first - self._remote_offline_serials = [] - - # map any user to a UserPresenceCache - self._user_cachemap = {} - self._user_cachemap_latest_serial = 0 - - # map room_ids to the latest presence serial for a member of that - # room - self._room_serials = {} + distributor = hs.get_distributor() + distributor.observe("user_joined_room", self.user_joined_room) + + active_presence = self.store.take_presence_startup_info() + + # A dictionary of the current state of users. This is prefilled with + # non-offline presence from the DB. We should fetch from the DB if + # we can't find a users presence in here. + self.user_to_current_state = { + state.user_id: state + for state in active_presence + } metrics.register_callback( - "userCachemap:size", - lambda: len(self._user_cachemap), + "user_to_current_state_size", lambda: len(self.user_to_current_state) ) - def _get_or_make_usercache(self, user): - """If the cache entry doesn't exist, initialise a new one.""" - if user not in self._user_cachemap: - self._user_cachemap[user] = UserPresenceCache() - return self._user_cachemap[user] - - def _get_or_offline_usercache(self, user): - """If the cache entry doesn't exist, return an OFFLINE one but do not - store it into the cache.""" - if user in self._user_cachemap: - return self._user_cachemap[user] - else: - return UserPresenceCache() + now = self.clock.time_msec() + for state in active_presence: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_active_ts + IDLE_TIMER, + ) + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ) + if self.hs.is_mine_id(state.user_id): + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL, + ) + else: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update_ts + FEDERATION_TIMEOUT, + ) - def registered_user(self, user): - return self.store.create_presence(user.localpart) + # Set of users who have presence in the `user_to_current_state` that + # have not yet been persisted + self.unpersisted_users_changes = set() - @defer.inlineCallbacks - def is_presence_visible(self, observer_user, observed_user): - assert(self.hs.is_mine(observed_user)) + reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) - if observer_user == observed_user: - defer.returnValue(True) + self.serial_to_user = {} + self._next_serial = 1 - if (yield self.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user])): - defer.returnValue(True) + # Keeps track of the number of *ongoing* syncs. While this is non zero + # a user will never go offline. + self.user_to_num_current_syncs = {} - if (yield self.store.is_presence_visible( - observed_localpart=observed_user.localpart, - observer_userid=observer_user.to_string())): - defer.returnValue(True) + # Start a LoopingCall in 30s that fires every 5s. + # The initial delay is to allow disconnected clients a chance to + # reconnect before we treat them as offline. + self.clock.call_later( + 0 * 1000, + self.clock.looping_call, + self._handle_timeouts, + 5000, + ) - defer.returnValue(False) + metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) @defer.inlineCallbacks - def get_state(self, target_user, auth_user, as_event=False, check_auth=True): - """Get the current presence state of the given user. + def _on_shutdown(self): + """Gets called when shutting down. This lets us persist any updates that + we haven't yet persisted, e.g. updates that only changes some internal + timers. This allows changes to persist across startup without having to + persist every single change. + + If this does not run it simply means that some of the timers will fire + earlier than they should when synapse is restarted. This affect of this + is some spurious presence changes that will self-correct. + """ + logger.info( + "Performing _on_shutdown. Persiting %d unpersisted changes", + len(self.user_to_current_state) + ) - Args: - target_user (UserID): The user whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_user` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + if self.unpersisted_users_changes: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ]) + logger.info("Finished _on_shutdown") - Returns: - dict: The presence state of the `target_user`, whose format depends - on the `as_event` argument. + @defer.inlineCallbacks + def _update_states(self, new_states): + """Updates presence of users. Sets the appropriate timeouts. Pokes + the notifier and federation if and only if the changed presence state + should be sent to clients/servers. """ - if self.hs.is_mine(target_user): - if check_auth: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=target_user - ) + now = self.clock.time_msec() - if not visible: - raise SynapseError(404, "Presence information not visible") + with Measure(self.clock, "presence_update_states"): - if target_user in self._user_cachemap: - state = self._user_cachemap[target_user].get_state() - else: - state = yield self.store.get_presence_state(target_user.localpart) - if "mtime" in state: - del state["mtime"] - state["presence"] = state.pop("state") - else: - # TODO(paul): Have remote server send us permissions set - state = self._get_or_offline_usercache(target_user).get_state() + # NOTE: We purposefully don't yield between now and when we've + # calculated what we want to do with the new states, to avoid races. - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + to_notify = {} # Changes we want to notify everyone about + to_federation_ping = {} # These need sending keep-alives - if as_event: - content = state + for new_state in new_states: + user_id = new_state.user_id - content["user_id"] = target_user.to_string() + # Its fine to not hit the database here, as the only thing not in + # the current state cache are OFFLINE states, where the only field + # of interest is last_active which is safe enough to assume is 0 + # here. + prev_state = self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") + new_state, should_notify, should_ping = handle_update( + prev_state, new_state, + is_mine=self.hs.is_mine_id(user_id), + wheel_timer=self.wheel_timer, + now=now ) - defer.returnValue({"type": "m.presence", "content": content}) - else: - defer.returnValue(state) + self.user_to_current_state[user_id] = new_state - @defer.inlineCallbacks - def get_states(self, target_users, auth_user, as_event=False, check_auth=True): - """A batched version of the `get_state` method that accepts a list of - `target_users` + if should_notify: + to_notify[user_id] = new_state + elif should_ping: + to_federation_ping[user_id] = new_state - Args: - target_users (list): The list of UserID's whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_users` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + # TODO: We should probably ensure there are no races hereafter - Returns: - dict: A mapping from user -> presence_state - """ - local_users, remote_users = partitionbool( - target_users, - lambda u: self.hs.is_mine(u) - ) + presence_updates_counter.inc_by(len(new_states)) + + if to_notify: + notified_presence_counter.inc_by(len(to_notify)) + yield self._persist_and_notify(to_notify.values()) + + self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes -= set(to_notify.keys()) - if check_auth: - for user in local_users: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=user + to_federation_ping = { + user_id: state for user_id, state in to_federation_ping.items() + if user_id not in to_notify + } + if to_federation_ping: + federation_presence_out_counter.inc_by(len(to_federation_ping)) + + _, _, hosts_to_states = yield self._get_interested_parties( + to_federation_ping.values() ) - if not visible: - raise SynapseError(404, "Presence information not visible") + self._push_to_remotes(hosts_to_states) + + def _handle_timeouts(self): + """Checks the presence of users that have timed out and updates as + appropriate. + """ + now = self.clock.time_msec() + + with Measure(self.clock, "presence_handle_timeouts"): + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = self.wheel_timer.fetch(now) - results = {} - if local_users: - for user in local_users: - if user in self._user_cachemap: - results[user] = self._user_cachemap[user].get_state() + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in set(users_to_check) + ] - local_to_user = {u.localpart: u for u in local_users} + timers_fired_counter.inc_by(len(states)) - states = yield self.store.get_presence_states( - [u.localpart for u in local_users if u not in results] + changes = handle_timeouts( + states, + is_mine_fn=self.hs.is_mine_id, + user_to_num_current_syncs=self.user_to_num_current_syncs, + now=now, ) - for local_part, state in states.items(): - if state is None: - continue - res = {"presence": state["state"]} - if "status_msg" in state and state["status_msg"]: - res["status_msg"] = state["status_msg"] - results[local_to_user[local_part]] = res - - for user in remote_users: - # TODO(paul): Have remote server send us permissions set - results[user] = self._get_or_offline_usercache(user).get_state() - - for state in results.values(): - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + preserve_fn(self._update_states)(changes) - if as_event: - for user, state in results.items(): - content = state - content["user_id"] = user.to_string() + @defer.inlineCallbacks + def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. + """ + user_id = user.to_string() - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + bump_active_time_counter.inc() - results[user] = {"type": "m.presence", "content": content} + prev_state = yield self.current_state_for_user(user_id) - defer.returnValue(results) + new_fields = { + "last_active_ts": self.clock.time_msec(), + } + if prev_state.state == PresenceState.UNAVAILABLE: + new_fields["state"] = PresenceState.ONLINE + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks - @log_function - def set_state(self, target_user, auth_user, state): - # return - # TODO (erikj): Turn this back on. Why did we end up sending EDUs - # everywhere? + def user_syncing(self, user_id, affect_presence=True): + """Returns a context manager that should surround any stream requests + from the user. - if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. - if target_user != auth_user: - raise AuthError(400, "Cannot set another user's presence") + Args: + user_id (str) + affect_presence (bool): If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + if affect_presence: + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + + prev_state = yield self.current_state_for_user(user_id) + if prev_state.state == PresenceState.OFFLINE: + # If they're currently offline then bring them online, otherwise + # just update the last sync times. + yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + )]) + else: + yield self._update_states([prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec(), + )]) - if "status_msg" not in state: - state["status_msg"] = None + @defer.inlineCallbacks + def _end(): + if affect_presence: + self.user_to_num_current_syncs[user_id] -= 1 - for k in state.keys(): - if k not in ("presence", "status_msg"): - raise SynapseError( - 400, "Unexpected presence state key '%s'" % (k,) - ) + prev_state = yield self.current_state_for_user(user_id) + yield self._update_states([prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec(), + )]) - if state["presence"] not in self.STATE_LEVELS: - raise SynapseError(400, "'%s' is not a valid presence state" % ( - state["presence"], - )) + @contextmanager + def _user_syncing(): + try: + yield + finally: + preserve_fn(_end)() - logger.debug("Updating presence state of %s to %s", - target_user.localpart, state["presence"]) + defer.returnValue(_user_syncing()) - state_to_store = dict(state) - state_to_store["state"] = state_to_store.pop("presence") + @defer.inlineCallbacks + def current_state_for_user(self, user_id): + """Get the current presence state for a user. + """ + res = yield self.current_state_for_users([user_id]) + defer.returnValue(res[user_id]) - statuscache = self._get_or_offline_usercache(target_user) - was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] - now_level = self.STATE_LEVELS[state["presence"]] + @defer.inlineCallbacks + def current_state_for_users(self, user_ids): + """Get the current presence state for multiple users. - yield self.store.set_presence_state( - target_user.localpart, state_to_store - ) - yield collect_presencelike_data(self.distributor, target_user, state) + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = yield self.store.get_presence_for_users(missing) + states.update({state.user_id: state for state in res}) + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + new = { + user_id: UserPresenceState.default(user_id) + for user_id in missing + } + states.update(new) + self.user_to_current_state.update(new) + + defer.returnValue(states) + + @defer.inlineCallbacks + def _get_interested_parties(self, states): + """Given a list of states return which entities (rooms, users, servers) + are interested in the given states. + + Returns: + 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + events = yield self.store.get_rooms_for_user(state.user_id) + for e in events: + room_ids_to_states.setdefault(e.room_id, []).append(state) + + plist = yield self.store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + hosts_to_states = {} + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) + if not local_states: + continue - if now_level > was_level: - state["last_active"] = self.clock.time_msec() + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) - now_online = state["presence"] != PresenceState.OFFLINE - was_polling = target_user in self._user_cachemap + for user_id, states in users_to_states.items(): + local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) + if not local_states: + continue - if now_online and not was_polling: - yield self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - yield self.stop_polling_presence(target_user) + host = UserID.from_string(user_id).domain + hosts_to_states.setdefault(host, []).extend(local_states) - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - yield self.changed_presencelike_data(target_user, state) + # TODO: de-dup hosts_to_states, as a single host might have multiple + # of same presence - def bump_presence_active_time(self, user, now=None): - if now is None: - now = self.clock.time_msec() + defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) - prev_state = self._get_or_make_usercache(user) - if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: - return + @defer.inlineCallbacks + def _persist_and_notify(self, states): + """Persist states in the database, poke the notifier and send to + interested remote servers + """ + stream_id, max_token = yield self.store.update_presence(states) + + parties = yield self._get_interested_parties(states) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) - with PreserveLoggingContext(): - self.changed_presencelike_data(user, {"last_active": now}) + self._push_to_remotes(hosts_to_states) - def get_joined_rooms_for_user(self, user): - """Get the list of rooms a user is joined to. + def _push_to_remotes(self, hosts_to_states): + """Sends state updates to remote servers. Args: - user(UserID): The user. - Returns: - A Deferred of a list of room id strings. + hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` """ - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_joined_rooms_for_user(user) + now = self.clock.time_msec() + for host, states in hosts_to_states.items(): + self.federation.send_edu( + destination=host, + edu_type="m.presence", + content={ + "push": [ + _format_user_presence_state(state, now) + for state in states + ] + } + ) - def get_joined_users_for_room_id(self, room_id): - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_room_members(room_id) + @defer.inlineCallbacks + def incoming_presence(self, origin, content): + """Called when we receive a `m.presence` EDU from a remote server. + """ + now = self.clock.time_msec() + updates = [] + for push in content.get("push", []): + # A "push" contains a list of presence that we are probably interested + # in. + # TODO: Actually check if we're interested, rather than blindly + # accepting presence updates. + user_id = push.get("user_id", None) + if not user_id: + logger.info( + "Got presence update from %r with no 'user_id': %r", + origin, push, + ) + continue + + presence_state = push.get("presence", None) + if not presence_state: + logger.info( + "Got presence update from %r with no 'presence_state': %r", + origin, push, + ) + continue + + new_fields = { + "state": presence_state, + "last_federation_update_ts": now, + } + + last_active_ago = push.get("last_active_ago", None) + if last_active_ago is not None: + new_fields["last_active_ts"] = now - last_active_ago + + new_fields["status_msg"] = push.get("status_msg", None) + new_fields["currently_active"] = push.get("currently_active", False) + + prev_state = yield self.current_state_for_user(user_id) + updates.append(prev_state.copy_and_replace(**new_fields)) + + if updates: + federation_presence_counter.inc_by(len(updates)) + yield self._update_states(updates) @defer.inlineCallbacks - def changed_presencelike_data(self, user, state): - """Updates the presence state of a local user. + def get_state(self, target_user, as_event=False): + results = yield self.get_states( + [target_user.to_string()], + as_event=as_event, + ) + + defer.returnValue(results[0]) + + @defer.inlineCallbacks + def get_states(self, target_user_ids, as_event=False): + """Get the presence state for users. Args: - user(UserID): The user being updated. - state(dict): The new presence state for the user. + target_user_ids (list) + as_event (bool): Whether to format it as a client event or not. + Returns: - A Deferred + list """ - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache(user, state) - yield self.push_presence(user, statuscache=statuscache) - @log_function - def started_user_eventstream(self, user): - # TODO(paul): Use "last online" state - return self.set_state(user, user, {"presence": PresenceState.ONLINE}) + updates = yield self.current_state_for_users(target_user_ids) + updates = updates.values() - @log_function - def stopped_user_eventstream(self, user): - # TODO(paul): Save current state as "last online" state - return self.set_state(user, user, {"presence": PresenceState.OFFLINE}) + for user_id in set(target_user_ids) - set(u.user_id for u in updates): + updates.append(UserPresenceState.default(user_id)) + + now = self.clock.time_msec() + if as_event: + defer.returnValue([ + { + "type": "m.presence", + "content": _format_user_presence_state(state, now), + } + for state in updates + ]) + else: + defer.returnValue([ + _format_user_presence_state(state, now) for state in updates + ]) @defer.inlineCallbacks - def user_joined_room(self, user, room_id): - """Called via the distributor whenever a user joins a room. - Notifies the new member of the presence of the current members. - Notifies the current members of the room of the new member's presence. + def set_state(self, target_user, state): + """Set the presence state of the user. + """ + status_msg = state.get("status_msg", None) + presence = state["presence"] - Args: - user(UserID): The user who joined the room. - room_id(str): The room id the user joined. + valid_presence = ( + PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE + ) + if presence not in valid_presence: + raise SynapseError(400, "Invalid presence state") + + user_id = target_user.to_string() + + prev_state = yield self.current_state_for_user(user_id) + + new_fields = { + "state": presence, + "status_msg": status_msg if presence != PresenceState.OFFLINE else None + } + + if presence == PresenceState.ONLINE: + new_fields["last_active_ts"] = self.clock.time_msec() + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + + @defer.inlineCallbacks + def user_joined_room(self, user, room_id): + """Called (via the distributor) when a user joins a room. This funciton + sends presence updates to servers, either: + 1. the joining user is a local user and we send their presence to + all servers in the room. + 2. the joining user is a remote user and so we send presence for all + local users in the room. """ + # We only need to send presence to servers that don't have it yet. We + # don't need to send to local clients here, as that is done as part + # of the event stream/sync. + # TODO: Only send to servers not already in the room. if self.hs.is_mine(user): - # No actual update but we need to bump the serial anyway for the - # event source - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache( - user, room_ids=[room_id] - ) - self.push_update_to_local_and_remote( - observed_user=user, - room_ids=[room_id], - statuscache=statuscache, - ) + state = yield self.current_state_for_user(user.to_string()) - # We also want to tell them about current presence of people. - curr_users = yield self.get_joined_users_for_room_id(room_id) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + self._push_to_remotes({host: (state,) for host in hosts}) + else: + user_ids = yield self.store.get_users_in_room(room_id) + user_ids = filter(self.hs.is_mine_id, user_ids) - for local_user in [c for c in curr_users if self.hs.is_mine(c)]: - statuscache = yield self.update_presence_cache( - local_user, room_ids=[room_id], add_to_cache=False - ) + states = yield self.current_state_for_users(user_ids) - with PreserveLoggingContext(): - self.push_update_to_local_and_remote( - observed_user=local_user, - users_to_push=[user], - statuscache=statuscache, - ) + self._push_to_remotes({user.domain: states.values()}) @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Request the presence of a local or remote user for a local user""" + def get_presence_list(self, observer_user, accepted=None): + """Returns the presence for all users in their presence list. + """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") + presence_list = yield self.store.get_presence_list( + observer_user.localpart, accepted=accepted + ) + + results = yield self.get_states( + target_user_ids=[row["observed_user_id"] for row in presence_list], + as_event=False, + ) + + is_accepted = { + row["observed_user_id"]: row["accepted"] for row in presence_list + } + + for result in results: + result.update({ + "accepted": is_accepted, + }) + + defer.returnValue(results) + + @defer.inlineCallbacks + def send_presence_invite(self, observer_user, observed_user): + """Sends a presence invite. + """ yield self.store.add_presence_list_pending( observer_user.localpart, observed_user.to_string() ) @@ -497,59 +680,40 @@ class PresenceHandler(BaseHandler): ) @defer.inlineCallbacks - def _should_accept_invite(self, observed_user, observer_user): - if not self.hs.is_mine(observed_user): - defer.returnValue(False) - - row = yield self.store.has_presence_state(observed_user.localpart) - if not row: - defer.returnValue(False) - - # TODO(paul): Eventually we'll ask the user's permission for this - # before accepting. For now just accept any invite request - defer.returnValue(True) - - @defer.inlineCallbacks def invite_presence(self, observed_user, observer_user): - """Handles a m.presence_invite EDU. A remote or local user has - requested presence updates for a local user. If the invite is accepted - then allow the local or remote user to see the presence of the local - user. - - Args: - observed_user(UserID): The local user whose presence is requested. - observer_user(UserID): The remote or local user requesting presence. + """Handles new presence invites. """ - accept = yield self._should_accept_invite(observed_user, observer_user) - - if accept: - yield self.store.allow_presence_visible( - observed_user.localpart, observer_user.to_string() - ) + if not self.hs.is_mine(observed_user): + raise SynapseError(400, "User is not hosted on this Home Server") + # TODO: Don't auto accept if self.hs.is_mine(observer_user): - if accept: - yield self.accept_presence(observed_user, observer_user) - else: - yield self.deny_presence(observed_user, observer_user) + yield self.accept_presence(observed_user, observer_user) else: - edu_type = "m.presence_accept" if accept else "m.presence_deny" - - yield self.federation.send_edu( + self.federation.send_edu( destination=observer_user.domain, - edu_type=edu_type, + edu_type="m.presence_accept", content={ "observed_user": observed_user.to_string(), "observer_user": observer_user.to_string(), } ) + state_dict = yield self.get_state(observed_user, as_event=False) + + self.federation.send_edu( + destination=observer_user.domain, + edu_type="m.presence", + content={ + "push": [state_dict] + } + ) + @defer.inlineCallbacks def accept_presence(self, observed_user, observer_user): """Handles a m.presence_accept EDU. Mark a presence invite from a local or remote user as accepted in a local user's presence list. Starts polling for presence updates from the local or remote user. - Args: observed_user(UserID): The user to update in the presence list. observer_user(UserID): The owner of the presence list to update. @@ -558,15 +722,10 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - yield self.start_polling_presence( - observer_user, target_user=observed_user - ) - @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): """Handle a m.presence_deny EDU. Removes a local or remote user from a local user's presence list. - Args: observed_user(UserID): The local or remote user to remove from the list. @@ -584,7 +743,6 @@ class PresenceHandler(BaseHandler): def drop(self, observed_user, observer_user): """Remove a local or remote user from a local user's presence list and unsubscribe the local user from updates that user. - Args: observed_user(UserId): The local or remote user to remove from the list. @@ -599,710 +757,353 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.stop_polling_presence( - observer_user, target_user=observed_user - ) - - @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Get the presence list for a local user. The retured list includes - the current presence state for each user listed. - - Args: - observer_user(UserID): The local user whose presence list to fetch. - accepted(bool or None): If not none then only include users who - have or have not accepted the presence invite request. - Returns: - A Deferred list of presence state events. - """ - if not self.hs.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) - - results = [] - for row in presence_list: - observed_user = UserID.from_string(row["observed_user_id"]) - result = { - "observed_user": observed_user, "accepted": row["accepted"] - } - result.update( - self._get_or_offline_usercache(observed_user).get_state() - ) - if "last_active" in result: - result["last_active_ago"] = int( - self.clock.time_msec() - result.pop("last_active") - ) - results.append(result) - - defer.returnValue(results) + # TODO: Inform the remote that we've dropped the presence list. @defer.inlineCallbacks - @log_function - def start_polling_presence(self, user, target_user=None, state=None): - """Subscribe a local user to presence updates from a local or remote - user. If no target_user is supplied then subscribe to all users stored - in the presence list for the local user. - - Additonally this pushes the current presence state of this user to all - target_users. That state can be provided directly or will be read from - the stored state for the local user. - - Also this attempts to notify the local user of the current state of - any local target users. - - Args: - user(UserID): The local user that whishes for presence updates. - target_user(UserID): The local or remote user whose updates are - wanted. - state(dict): Optional presence state for the local user. + def is_visible(self, observed_user, observer_user): + """Returns whether a user can see another user's presence. """ - logger.debug("Start polling for presence from %s", user) - - if target_user: - target_users = set([target_user]) - room_ids = [] - else: - presence = yield self.store.get_presence_list( - user.localpart, accepted=True - ) - target_users = set([ - UserID.from_string(x["observed_user_id"]) for x in presence - ]) + observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) + observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) - # Also include people in all my rooms + observer_room_ids = set(r.room_id for r in observer_rooms) + observed_room_ids = set(r.room_id for r in observed_rooms) - room_ids = yield self.get_joined_rooms_for_user(user) + if observer_room_ids & observed_room_ids: + defer.returnValue(True) - if state is None: - state = yield self.store.get_presence_state(user.localpart) - else: - # statuscache = self._get_or_make_usercache(user) - # self._user_cachemap_latest_serial += 1 - # statuscache.update(state, self._user_cachemap_latest_serial) - pass - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=target_users, - room_ids=room_ids, - statuscache=self._get_or_make_usercache(user), + accepted_observers = yield self.store.get_presence_list_observers_accepted( + observed_user.to_string() ) - for target_user in target_users: - if self.hs.is_mine(target_user): - self._start_polling_local(user, target_user) - - # We want to tell the person that just came online - # presence state of people they are interested in? - self.push_update_to_clients( - users_to_push=[user], - ) - - deferreds = [] - remote_users = [u for u in target_users if not self.hs.is_mine(u)] - remoteusers_by_domain = partition(remote_users, lambda u: u.domain) - # Only poll for people in our get_presence_list - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append(self._start_polling_remote( - user, domain, remoteusers - )) - - yield defer.DeferredList(deferreds, consumeErrors=True) + defer.returnValue(observer_user.to_string() in accepted_observers) - def _start_polling_local(self, user, target_user): - """Subscribe a local user to presence updates for a local user - - Args: - user(UserId): The local user that wishes for updates. - target_user(UserId): The local users whose updates are wanted. + @defer.inlineCallbacks + def get_all_presence_updates(self, last_id, current_id): """ - target_localpart = target_user.localpart - - if target_localpart not in self._local_pushmap: - self._local_pushmap[target_localpart] = set() - - self._local_pushmap[target_localpart].add(user) - - def _start_polling_remote(self, user, domain, remoteusers): - """Subscribe a local user to presence updates for remote users on a - given remote domain. - - Args: - user(UserID): The local user that wishes for updates. - domain(str): The remote server the local user wants updates from. - remoteusers(UserID): The remote users that local user wants to be - told about. - Returns: - A Deferred. + Gets a list of presence update rows from between the given stream ids. + Each row has: + - stream_id(str) + - user_id(str) + - state(str) + - last_active_ts(int) + - last_federation_update_ts(int) + - last_user_sync_ts(int) + - status_msg(int) + - currently_active(int) """ - to_poll = set() - - for u in remoteusers: - if u not in self._remote_recvmap: - self._remote_recvmap[u] = set() - to_poll.add(u) - - self._remote_recvmap[u].add(user) - - if not to_poll: - return defer.succeed(None) + # TODO(markjh): replicate the unpersisted changes. + # This could use the in-memory stores for recent changes. + rows = yield self.store.get_all_presence_updates(last_id, current_id) + defer.returnValue(rows) - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"poll": [u.to_string() for u in to_poll]} - ) - - @log_function - def stop_polling_presence(self, user, target_user=None): - """Unsubscribe a local user from presence updates from a local or - remote user. If no target user is supplied then unsubscribe the user - from all presence updates that the user had subscribed to. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID or None): The user whose updates are no longer - wanted. - Returns: - A Deferred. - """ - logger.debug("Stop polling for presence from %s", user) - if not target_user or self.hs.is_mine(target_user): - self._stop_polling_local(user, target_user=target_user) +def should_notify(old_state, new_state): + """Decides if a presence state change should be sent to interested parties. + """ + if old_state.status_msg != new_state.status_msg: + return True - deferreds = [] + if old_state.state == PresenceState.ONLINE: + if new_state.state != PresenceState.ONLINE: + # Always notify for online -> anything + return True - if target_user: - if target_user not in self._remote_recvmap: - return - target_users = set([target_user]) - else: - target_users = self._remote_recvmap.keys() + if new_state.currently_active != old_state.currently_active: + return True - remoteusers = [u for u in target_users - if user in self._remote_recvmap[u]] - remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + return True - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] + if old_state.state != new_state.state: + return True - deferreds.append( - self._stop_polling_remote(user, domain, remoteusers) - ) + return False - return defer.DeferredList(deferreds, consumeErrors=True) - def _stop_polling_local(self, user, target_user): - """Unsubscribe a local user from presence updates from a local user on - this server. +def _format_user_presence_state(state, now): + """Convert UserPresenceState to a format that can be sent down to clients + and to other servers. + """ + content = { + "presence": state.state, + "user_id": state.user_id, + } + if state.last_active_ts: + content["last_active_ago"] = now - state.last_active_ts + if state.status_msg and state.state != PresenceState.OFFLINE: + content["status_msg"] = state.status_msg + if state.state == PresenceState.ONLINE: + content["currently_active"] = state.currently_active - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID): The user whose updates are no longer wanted. - """ - for localpart in self._local_pushmap.keys(): - if target_user and localpart != target_user.localpart: - continue + return content - if user in self._local_pushmap[localpart]: - self._local_pushmap[localpart].remove(user) - if not self._local_pushmap[localpart]: - del self._local_pushmap[localpart] +class PresenceEventSource(object): + def __init__(self, hs): + self.hs = hs + self.clock = hs.get_clock() + self.store = hs.get_datastore() + @defer.inlineCallbacks @log_function - def _stop_polling_remote(self, user, domain, remoteusers): - """Unsubscribe a local user from presence updates from remote users on - a given domain. - - Args: - user(UserID): The local user that no longer wishes for updates. - domain(str): The remote server to unsubscribe from. - remoteusers([UserID]): The users on that remote server that the - local user no longer wishes to be updated about. - Returns: - A Deferred. - """ - to_unpoll = set() - - for u in remoteusers: - self._remote_recvmap[u].remove(user) - - if not self._remote_recvmap[u]: - del self._remote_recvmap[u] - to_unpoll.add(u) + def get_new_events(self, user, from_key, room_ids=None, include_offline=True, + **kwargs): + # The process for getting presence events are: + # 1. Get the rooms the user is in. + # 2. Get the list of user in the rooms. + # 3. Get the list of users that are in the user's presence list. + # 4. If there is a from_key set, cross reference the list of users + # with the `presence_stream_cache` to see which ones we actually + # need to check. + # 5. Load current state for the users. + # + # We don't try and limit the presence updates by the current token, as + # sending down the rare duplicate is not a concern. + + with Measure(self.clock, "presence.get_new_events"): + user_id = user.to_string() + if from_key is not None: + from_key = int(from_key) + room_ids = room_ids or [] - if not to_unpoll: - return defer.succeed(None) + presence = self.hs.get_handlers().presence_handler + stream_change_cache = self.store.presence_stream_cache - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"unpoll": [u.to_string() for u in to_unpoll]} - ) + if not room_ids: + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(e.room_id for e in rooms) + else: + room_ids = set(room_ids) + + max_token = self.store.get_current_presence_token() + + plist = yield self.store.get_presence_list_accepted(user.localpart) + friends = set(row["observed_user_id"] for row in plist) + friends.add(user_id) # So that we receive our own presence + + user_ids_changed = set() + changed = None + if from_key and max_token - from_key < 100: + # For small deltas, its quicker to get all changes and then + # work out if we share a room or they're in our presence list + changed = stream_change_cache.get_all_entities_changed(from_key) + + # get_all_entities_changed can return None + if changed is not None: + for other_user_id in changed: + if other_user_id in friends: + user_ids_changed.add(other_user_id) + continue + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + continue + else: + # Too many possible updates. Find all users we can see and check + # if any of them have changed. + user_ids_to_check = set() + for room_id in room_ids: + users = yield self.store.get_users_in_room(room_id) + user_ids_to_check.update(users) + + user_ids_to_check.update(friends) + + # Always include yourself. Only really matters for when the user is + # not in any rooms, but still. + user_ids_to_check.add(user_id) + + if from_key: + user_ids_changed = stream_change_cache.get_entities_changed( + user_ids_to_check, from_key, + ) + else: + user_ids_changed = user_ids_to_check - @defer.inlineCallbacks - @log_function - def push_presence(self, user, statuscache): - """ - Notify local and remote users of a change in presence of a local user. - Pushes the update to local clients and remote domains that are directly - subscribed to the presence of the local user. - Also pushes that update to any local user or remote domain that shares - a room with the local user. + updates = yield presence.current_state_for_users(user_ids_changed) - Args: - user(UserID): The local user whose presence was updated. - statuscache(UserPresenceCache): Cache of the user's presence state - Returns: - A Deferred. - """ - assert(self.hs.is_mine(user)) + now = self.clock.time_msec() - logger.debug("Pushing presence update from %s", user) + defer.returnValue(([ + { + "type": "m.presence", + "content": _format_user_presence_state(s, now), + } + for s in updates.values() + if include_offline or s.state != PresenceState.OFFLINE + ], max_token)) - localusers = set(self._local_pushmap.get(user.localpart, set())) - remotedomains = set(self._remote_sendmap.get(user.localpart, set())) + def get_current_key(self): + return self.store.get_current_presence_token() - # Reflect users' status changes back to themselves, so UIs look nice - # and also user is informed of server-forced pushes - localusers.add(user) + def get_pagination_rows(self, user, pagination_config, key): + return self.get_new_events(user, from_key=None, include_offline=False) - room_ids = yield self.get_joined_rooms_for_user(user) - if not localusers and not room_ids: - defer.returnValue(None) +def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): + """Checks the presence of users that have timed out and updates as + appropriate. - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=localusers, - remote_domains=remotedomains, - room_ids=room_ids, - statuscache=statuscache, - ) - yield user_presence_changed(self.distributor, user, statuscache) + Args: + user_states(list): List of UserPresenceState's to check. + is_mine_fn (fn): Function that returns if a user_id is ours + user_to_num_current_syncs (dict): Mapping of user_id to number of currently + active syncs. + now (int): Current time in ms. - @defer.inlineCallbacks - def incoming_presence(self, origin, content): - """Handle an incoming m.presence EDU. - For each presence update in the "push" list update our local cache and - notify the appropriate local clients. Only clients that share a room - or are directly subscribed to the presence for a user should be - notified of the update. - For each subscription request in the "poll" list start pushing presence - updates to the remote server. - For unsubscribe request in the "unpoll" list stop pushing presence - updates to the remote server. + Returns: + List of UserPresenceState updates + """ + changes = {} # Actual changes we need to notify people about - Args: - orgin(str): The source of this m.presence EDU. - content(dict): The content of this m.presence EDU. - Returns: - A Deferred. - """ - deferreds = [] + for state in user_states: + is_mine = is_mine_fn(state.user_id) - for push in content.get("push", []): - user = UserID.from_string(push["user_id"]) + new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + if new_state: + changes[state.user_id] = new_state - logger.debug("Incoming presence update from %s", user) + return changes.values() - observers = set(self._remote_recvmap.get(user, set())) - if observers: - logger.debug( - " | %d interested local observers %r", len(observers), observers - ) - room_ids = yield self.get_joined_rooms_for_user(user) - if room_ids: - logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) +def handle_timeout(state, is_mine, user_to_num_current_syncs, now): + """Checks the presence of the user to see if any of the timers have elapsed - state = dict(push) - del state["user_id"] + Args: + state (UserPresenceState) + is_mine (bool): Whether the user is ours + user_to_num_current_syncs (dict): Mapping of user_id to number of currently + active syncs. + now (int): Current time in ms. - if "presence" not in state: - logger.warning( - "Received a presence 'push' EDU from %s without a" - " 'presence' key", origin + Returns: + A UserPresenceState update or None if no update. + """ + if state.state == PresenceState.OFFLINE: + # No timeouts are associated with offline states. + return None + + changed = False + user_id = state.user_id + + if is_mine: + if state.state == PresenceState.ONLINE: + if now - state.last_active_ts > IDLE_TIMER: + # Currently online, but last activity ages ago so auto + # idle + state = state.copy_and_replace( + state=PresenceState.UNAVAILABLE, ) - continue - - if "last_active_ago" in state: - state["last_active"] = int( - self.clock.time_msec() - state.pop("last_active_ago") + changed = True + elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # So that we send down a notification that we've + # stopped updating. + changed = True + + if now - state.last_federation_update_ts > FEDERATION_PING_INTERVAL: + # Need to send ping to other servers to ensure they don't + # timeout and set us to offline + changed = True + + # If there are have been no sync for a while (and none ongoing), + # set presence to offline + if not user_to_num_current_syncs.get(user_id, 0): + if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: + state = state.copy_and_replace( + state=PresenceState.OFFLINE, + status_msg=None, ) - - self._user_cachemap_latest_serial += 1 - yield self.update_presence_cache(user, state, room_ids=room_ids) - - if not observers and not room_ids: - logger.debug(" | no interested observers or room IDs") - continue - - self.push_update_to_clients( - users_to_push=observers, room_ids=room_ids + changed = True + else: + # We expect to be poked occaisonally by the other side. + # This is to protect against forgetful/buggy servers, so that + # no one gets stuck online forever. + if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: + # The other side seems to have disappeared. + state = state.copy_and_replace( + state=PresenceState.OFFLINE, + status_msg=None, ) + changed = True - user_id = user.to_string() - - if state["presence"] == PresenceState.OFFLINE: - self._remote_offline_serials.insert( - 0, - (self._user_cachemap_latest_serial, set([user_id])) - ) - while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: - self._remote_offline_serials.pop() # remove the oldest - if user in self._user_cachemap: - del self._user_cachemap[user] - else: - # Remove the user from remote_offline_serials now that they're - # no longer offline - for idx, elem in enumerate(self._remote_offline_serials): - (_, user_ids) = elem - user_ids.discard(user_id) - if not user_ids: - self._remote_offline_serials.pop(idx) - - for poll in content.get("poll", []): - user = UserID.from_string(poll) - - if not self.hs.is_mine(user): - continue + return state if changed else None - # TODO(paul) permissions checks - - if user not in self._remote_sendmap: - self._remote_sendmap[user] = set() - - self._remote_sendmap[user].add(origin) - - deferreds.append(self._push_presence_remote(user, origin)) - - for unpoll in content.get("unpoll", []): - user = UserID.from_string(unpoll) - - if not self.hs.is_mine(user): - continue - if user in self._remote_sendmap: - self._remote_sendmap[user].remove(origin) +def handle_update(prev_state, new_state, is_mine, wheel_timer, now): + """Given a presence update: + 1. Add any appropriate timers. + 2. Check if we should notify anyone. - if not self._remote_sendmap[user]: - del self._remote_sendmap[user] + Args: + prev_state (UserPresenceState) + new_state (UserPresenceState) + is_mine (bool): Whether the user is ours + wheel_timer (WheelTimer) + now (int): Time now in ms - yield defer.DeferredList(deferreds, consumeErrors=True) - - @defer.inlineCallbacks - def update_presence_cache(self, user, state={}, room_ids=None, - add_to_cache=True): - """Update the presence cache for a user with a new state and bump the - serial to the latest value. - - Args: - user(UserID): The user being updated - state(dict): The presence state being updated - room_ids(None or list of str): A list of room_ids to update. If - room_ids is None then fetch the list of room_ids the user is - joined to. - add_to_cache: Whether to add an entry to the presence cache if the - user isn't already in the cache. - Returns: - A Deferred UserPresenceCache for the user being updated. - """ - if room_ids is None: - room_ids = yield self.get_joined_rooms_for_user(user) - - for room_id in room_ids: - self._room_serials[room_id] = self._user_cachemap_latest_serial - if add_to_cache: - statuscache = self._get_or_make_usercache(user) - else: - statuscache = self._get_or_offline_usercache(user) - statuscache.update(state, serial=self._user_cachemap_latest_serial) - defer.returnValue(statuscache) - - @defer.inlineCallbacks - def push_update_to_local_and_remote(self, observed_user, statuscache, - users_to_push=[], room_ids=[], - remote_domains=[]): - """Notify local clients and remote servers of a change in the presence - of a user. - - Args: - observed_user(UserID): The user to push the presence state for. - statuscache(UserPresenceCache): The cache for the presence state to - push. - users_to_push([UserID]): A list of local and remote users to - notify. - room_ids([str]): Notify the local and remote occupants of these - rooms. - remote_domains([str]): A list of remote servers to notify in - addition to those implied by the users_to_push and the - room_ids. - Returns: - A Deferred. - """ - - localusers, remoteusers = partitionbool( - users_to_push, - lambda u: self.hs.is_mine(u) - ) - - localusers = set(localusers) - - self.push_update_to_clients( - users_to_push=localusers, room_ids=room_ids - ) - - remote_domains = set(remote_domains) - remote_domains |= set([r.domain for r in remoteusers]) - for room_id in room_ids: - remote_domains.update( - (yield self.store.get_joined_hosts_for_room(room_id)) - ) - - remote_domains.discard(self.hs.hostname) - - deferreds = [] - for domain in remote_domains: - logger.debug(" | push to remote domain %s", domain) - deferreds.append( - self._push_presence_remote( - observed_user, domain, state=statuscache.get_state() - ) + Returns: + 3-tuple: `(new_state, persist_and_notify, federation_ping)` where: + - new_state: is the state to actually persist + - persist_and_notify (bool): whether to persist and notify people + - federation_ping (bool): whether we should send a ping over federation + """ + user_id = new_state.user_id + + persist_and_notify = False + federation_ping = False + + # If the users are ours then we want to set up a bunch of timers + # to time things out. + if is_mine: + if new_state.state == PresenceState.ONLINE: + # Idle timer + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active_ts + IDLE_TIMER ) - yield defer.DeferredList(deferreds, consumeErrors=True) - - defer.returnValue((localusers, remote_domains)) - - def push_update_to_clients(self, users_to_push=[], room_ids=[]): - """Notify clients of a new presence event. - - Args: - users_to_push([UserID]): List of users to notify. - room_ids([str]): List of room_ids to notify. - """ - with PreserveLoggingContext(): - self.notifier.on_new_event( - "presence_key", - self._user_cachemap_latest_serial, - users_to_push, - room_ids, + active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY + new_state = new_state.copy_and_replace( + currently_active=active, ) - @defer.inlineCallbacks - def _push_presence_remote(self, user, destination, state=None): - """Push a user's presence to a remote server. If a presence state event - that event is sent. Otherwise a new state event is constructed from the - stored presence state. - The last_active is replaced with last_active_ago in case the wallclock - time on the remote server is different to the time on this server. - Sends an EDU to the remote server with the current presence state. - - Args: - user(UserID): The user to push the presence state for. - destination(str): The remote server to send state to. - state(dict): The state to push, or None to use the current stored - state. - Returns: - A Deferred. - """ - if state is None: - state = yield self.store.get_presence_state(user.localpart) - del state["mtime"] - state["presence"] = state.pop("state") - - if user in self._user_cachemap: - state["last_active"] = ( - self._user_cachemap[user].get_state()["last_active"] + if active: + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY ) - yield collect_presencelike_data(self.distributor, user, state) - - if "last_active" in state: - state = dict(state) - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") + if new_state.state != PresenceState.OFFLINE: + # User has stopped syncing + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT ) - user_state = {"user_id": user.to_string(), } - user_state.update(state) - - yield self.federation.send_edu( - destination=destination, - edu_type="m.presence", - content={"push": [user_state, ], } - ) - - -class PresenceEventSource(object): - def __init__(self, hs): - self.hs = hs - self.clock = hs.get_clock() - - @defer.inlineCallbacks - @log_function - def get_new_events(self, user, from_key, room_ids=None, **kwargs): - from_key = int(from_key) - room_ids = room_ids or [] - - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - max_serial = presence._user_cachemap_latest_serial - - clock = self.clock - latest_serial = 0 - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] > from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - cached = cachemap[observed_user] - - if cached.serial <= from_key or cached.serial > max_serial: - continue - - latest_serial = max(cached.serial, latest_serial) - updates.append(cached.make_event(user=observed_user, clock=clock)) - - # TODO(paul): limit - - for serial, user_ids in presence._remote_offline_serials: - if serial <= from_key: - break - - if serial > max_serial: - continue - - latest_serial = max(latest_serial, serial) - for u in user_ids: - updates.append({ - "type": "m.presence", - "content": {"user_id": u, "presence": PresenceState.OFFLINE}, - }) - # TODO(paul): For the v2 API we want to tell the client their from_key - # is too old if we fell off the end of the _remote_offline_serials - # list, and get them to invalidate+resync. In v1 we have no such - # concept so this is a best-effort result. - - if updates: - defer.returnValue((updates, latest_serial)) - else: - defer.returnValue(([], presence._user_cachemap_latest_serial)) - - def get_current_key(self): - presence = self.hs.get_handlers().presence_handler - return presence._user_cachemap_latest_serial - - @defer.inlineCallbacks - def get_pagination_rows(self, user, pagination_config, key): - # TODO (erikj): Does this make sense? Ordering? - - from_key = int(pagination_config.from_key) - - if pagination_config.to_key: - to_key = int(pagination_config.to_key) - else: - to_key = -1 - - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap + last_federate = new_state.last_federation_update_ts + if now - last_federate > FEDERATION_PING_INTERVAL: + # Been a while since we've poked remote servers + new_state = new_state.copy_and_replace( + last_federation_update_ts=now, + ) + federation_ping = True - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True + else: + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - room_ids = yield presence.get_joined_rooms_for_user(user) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] >= from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - if not (to_key < cachemap[observed_user].serial <= from_key): - continue - - updates.append((observed_user, cachemap[observed_user])) - - # TODO(paul): limit - - if updates: - clock = self.clock - - earliest_serial = max([x[1].serial for x in updates]) - data = [x[1].make_event(user=x[0], clock=clock) for x in updates] - - defer.returnValue((data, earliest_serial)) - else: - defer.returnValue(([], 0)) - -class UserPresenceCache(object): - """Store an observed user's state and status message. - - Includes the update timestamp. - """ - def __init__(self): - self.state = {"presence": PresenceState.OFFLINE} - self.serial = None - - def __repr__(self): - return "UserPresenceCache(state=%r, serial=%r)" % ( - self.state, self.serial + # Check whether the change was something worth notifying about + if should_notify(prev_state, new_state): + new_state = new_state.copy_and_replace( + last_federation_update_ts=now, ) + persist_and_notify = True - def update(self, state, serial): - assert("mtime_age" not in state) - - self.state.update(state) - # Delete keys that are now 'None' - for k in self.state.keys(): - if self.state[k] is None: - del self.state[k] - - self.serial = serial - - if "status_msg" in state: - self.status_msg = state["status_msg"] - else: - self.status_msg = None - - def get_state(self): - # clone it so caller can't break our cache - state = dict(self.state) - return state - - def make_event(self, user, clock): - content = self.get_state() - content["user_id"] = user.to_string() - - if "last_active" in content: - content["last_active_ago"] = int( - clock.time_msec() - content.pop("last_active") - ) - - return {"type": "m.presence", "content": content} + return new_state, persist_and_notify, federation_ping diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 629e6e3594..c9ad5944e6 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -16,8 +16,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID +from synapse.types import UserID, Requester from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -49,6 +48,9 @@ class ProfileHandler(BaseHandler): distributor = hs.get_distributor() self.distributor = distributor + distributor.declare("collect_presencelike_data") + distributor.declare("changed_presencelike_data") + distributor.observe("registered_user", self.registered_user) distributor.observe( @@ -208,21 +210,18 @@ class ProfileHandler(BaseHandler): ) for j in joins: - content = { - "membership": Membership.JOIN, - } - - yield collect_presencelike_data(self.distributor, user, content) - - msg_handler = self.hs.get_handlers().message_handler + handler = self.hs.get_handlers().room_member_handler try: - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "room_id": j.room_id, - "state_key": user.to_string(), - "content": content, - "sender": user.to_string() - }, ratelimit=False) + # Assume the user isn't a guest because we don't let guests set + # profile or avatar data. + requester = Requester(user, "", False) + yield handler.update_membership( + requester, + user, + j.room_id, + "join", # We treat a profile update like a join. + ratelimit=False, # Try to hide that these events aren't atomic. + ) except Exception as e: logger.warn( "Failed to update join event for room %s - %s", diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index de4c694714..935c339707 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -36,8 +36,6 @@ class ReceiptsHandler(BaseHandler): ) self.clock = self.hs.get_clock() - self._receipt_cache = None - @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, event_id): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 24c850ae9b..6d155d57e7 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -60,7 +60,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) + yield self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: @@ -145,7 +145,7 @@ class RegistrationHandler(BaseHandler): localpart = yield self._generate_user_id(attempts > 0) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) + yield self.check_user_id_not_appservice_exclusive(user_id) if generate_token: token = self.auth_handler().generate_access_token(user_id) try: @@ -180,6 +180,11 @@ class RegistrationHandler(BaseHandler): 400, "Invalid user localpart for this application service.", errcode=Codes.EXCLUSIVE ) + + yield self.check_user_id_not_appservice_exclusive( + user_id, allowed_appservice=service + ) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, @@ -226,7 +231,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) + yield self.check_user_id_not_appservice_exclusive(user_id) token = self.auth_handler().generate_access_token(user_id) try: yield self.store.register( @@ -278,12 +283,14 @@ class RegistrationHandler(BaseHandler): yield identity_handler.bind_threepid(c, user_id) @defer.inlineCallbacks - def check_user_id_is_valid(self, user_id): + def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = yield self.store.get_app_services() interested_services = [ - s for s in services if s.is_interested_in_user(user_id) + s for s in services + if s.is_interested_in_user(user_id) + and s != allowed_appservice ] for service in interested_services: if service.is_exclusive_user(user_id): @@ -342,3 +349,18 @@ class RegistrationHandler(BaseHandler): def auth_handler(self): return self.hs.get_handlers().auth_handler + + @defer.inlineCallbacks + def guest_access_token_for(self, medium, address, inviter_user_id): + access_token = yield self.store.get_3pid_guest_access_token(medium, address) + if access_token: + defer.returnValue(access_token) + + _, access_token = yield self.register( + generate_token=True, + make_guest=True + ) + access_token = yield self.store.save_or_get_3pid_guest_access_token( + medium, address, access_token, inviter_user_id + ) + defer.returnValue(access_token) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b2de2cd0c0..d2de23a6cc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -24,7 +24,6 @@ from synapse.api.constants import ( ) from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.util import stringutils, unwrapFirstError -from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn from signedjson.sign import verify_signed_json @@ -42,10 +41,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - def user_left_room(distributor, user, room_id): return preserve_context_over_fn( distributor.fire, @@ -81,20 +76,20 @@ class RoomCreationHandler(BaseHandler): } @defer.inlineCallbacks - def create_room(self, user_id, room_id, config): + def create_room(self, requester, config): """ Creates a new room. Args: - user_id (str): The ID of the user creating the new room. - room_id (str): The proposed ID for the new room. Can be None, in - which case one will be created for you. + requester (Requester): The user who requested the room creation. config (dict) : A dict of configuration options. Returns: The new room ID. Raises: - SynapseError if the room ID was taken, couldn't be stored, or - something went horribly wrong. + SynapseError if the room ID couldn't be stored, or something went + horribly wrong. """ + user_id = requester.user.to_string() + self.ratelimit(user_id) if "room_alias_name" in config: @@ -126,40 +121,28 @@ class RoomCreationHandler(BaseHandler): is_public = config.get("visibility", None) == "public" - if room_id: - # Ensure room_id is the correct type - room_id_obj = RoomID.from_string(room_id) - if not self.hs.is_mine(room_id_obj): - raise SynapseError(400, "Room id must be local") - - yield self.store.store_room( - room_id=room_id, - room_creator_user_id=user_id, - is_public=is_public - ) - else: - # autogen room IDs and try to create it. We may clash, so just - # try a few times till one goes through, giving up eventually. - attempts = 0 - room_id = None - while attempts < 5: - try: - random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( - random_string, - self.hs.hostname, - ) - yield self.store.store_room( - room_id=gen_room_id.to_string(), - room_creator_user_id=user_id, - is_public=is_public - ) - room_id = gen_room_id.to_string() - break - except StoreError: - attempts += 1 - if not room_id: - raise StoreError(500, "Couldn't generate a room ID.") + # autogen room IDs and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + room_id = None + while attempts < 5: + try: + random_string = stringutils.random_string(18) + gen_room_id = RoomID.create( + random_string, + self.hs.hostname, + ) + yield self.store.store_room( + room_id=gen_room_id.to_string(), + room_creator_user_id=user_id, + is_public=is_public + ) + room_id = gen_room_id.to_string() + break + except StoreError: + attempts += 1 + if not room_id: + raise StoreError(500, "Couldn't generate a room ID.") if room_alias: directory_handler = self.hs.get_handlers().directory_handler @@ -185,9 +168,14 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) - user = UserID.from_string(user_id) - creation_events = self._create_events_for_new_room( - user, room_id, + msg_handler = self.hs.get_handlers().message_handler + room_member_handler = self.hs.get_handlers().room_member_handler + + yield self._send_events_for_new_room( + requester, + room_id, + msg_handler, + room_member_handler, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, @@ -195,14 +183,9 @@ class RoomCreationHandler(BaseHandler): room_alias=room_alias, ) - msg_handler = self.hs.get_handlers().message_handler - - for event in creation_events: - yield msg_handler.create_and_send_event(event, ratelimit=False) - if "name" in config: name = config["name"] - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Name, "room_id": room_id, "sender": user_id, @@ -212,7 +195,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Topic, "room_id": room_id, "sender": user_id, @@ -221,13 +204,13 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) for invitee in invite_list: - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "state_key": invitee, - "room_id": room_id, - "sender": user_id, - "content": {"membership": Membership.INVITE}, - }, ratelimit=False) + room_member_handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + ratelimit=False, + ) for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] @@ -235,11 +218,11 @@ class RoomCreationHandler(BaseHandler): medium = invite_3pid["medium"] yield self.hs.get_handlers().room_member_handler.do_3pid_invite( room_id, - user, + requester.user, medium, address, id_server, - token_id=None, + requester, txn_id=None, ) @@ -253,19 +236,19 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) - def _create_events_for_new_room(self, creator, room_id, preset_config, - invite_list, initial_state, creation_content, - room_alias): - config = RoomCreationHandler.PRESETS_DICT[preset_config] - - creator_id = creator.to_string() - - event_keys = { - "room_id": room_id, - "sender": creator_id, - "state_key": "", - } - + @defer.inlineCallbacks + def _send_events_for_new_room( + self, + creator, # A Requester object. + room_id, + msg_handler, + room_member_handler, + preset_config, + invite_list, + initial_state, + creation_content, + room_alias + ): def create(etype, content, **kwargs): e = { "type": etype, @@ -277,26 +260,39 @@ class RoomCreationHandler(BaseHandler): return e - creation_content.update({"creator": creator.to_string()}) - creation_event = create( + @defer.inlineCallbacks + def send(etype, content, **kwargs): + event = create(etype, content, **kwargs) + yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) + + config = RoomCreationHandler.PRESETS_DICT[preset_config] + + creator_id = creator.user.to_string() + + event_keys = { + "room_id": room_id, + "sender": creator_id, + "state_key": "", + } + + creation_content.update({"creator": creator_id}) + yield send( etype=EventTypes.Create, content=creation_content, ) - join_event = create( - etype=EventTypes.Member, - state_key=creator_id, - content={ - "membership": Membership.JOIN, - }, + yield room_member_handler.update_membership( + creator, + creator.user, + room_id, + "join", + ratelimit=False, ) - returned_events = [creation_event, join_event] - if (EventTypes.PowerLevels, '') not in initial_state: power_level_content = { "users": { - creator.to_string(): 100, + creator_id: 100, }, "users_default": 0, "events": { @@ -318,45 +314,35 @@ class RoomCreationHandler(BaseHandler): for invitee in invite_list: power_level_content["users"][invitee] = 100 - power_levels_event = create( + yield send( etype=EventTypes.PowerLevels, content=power_level_content, ) - returned_events.append(power_levels_event) - if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: - room_alias_event = create( + yield send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) - returned_events.append(room_alias_event) - if (EventTypes.JoinRules, '') not in initial_state: - join_rules_event = create( + yield send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}, ) - returned_events.append(join_rules_event) - if (EventTypes.RoomHistoryVisibility, '') not in initial_state: - history_event = create( + yield send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]} ) - returned_events.append(history_event) - for (etype, state_key), content in initial_state.items(): - returned_events.append(create( + yield send( etype=etype, state_key=state_key, content=content, - )) - - return returned_events + ) class RoomMemberHandler(BaseHandler): @@ -404,16 +390,35 @@ class RoomMemberHandler(BaseHandler): remotedomains.add(member.domain) @defer.inlineCallbacks - def update_membership(self, requester, target, room_id, action, txn_id=None): + def update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + ): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" elif action == "forget": effective_membership_state = "leave" + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + msg_handler = self.hs.get_handlers().message_handler - content = {"membership": unicode(effective_membership_state)} + content = {"membership": effective_membership_state} if requester.is_guest: content["kind"] = "guest" @@ -424,6 +429,9 @@ class RoomMemberHandler(BaseHandler): "room_id": room_id, "sender": requester.user.to_string(), "state_key": target.to_string(), + + # For backwards compatibility: + "membership": effective_membership_state, }, token_id=requester.access_token_id, txn_id=txn_id, @@ -444,202 +452,181 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) - yield msg_handler.send_event( + member_handler = self.hs.get_handlers().room_member_handler + yield member_handler.send_membership_event( event, context, - ratelimit=True, - is_guest=requester.is_guest + is_guest=requester.is_guest, + ratelimit=ratelimit, + remote_room_hosts=remote_room_hosts, + from_client=True, ) if action == "forget": yield self.forget(requester.user, room_id) @defer.inlineCallbacks - def send_membership_event(self, event, context, is_guest=False): - """ Change the membership status of a user in a room. + def send_membership_event( + self, + event, + context, + is_guest=False, + remote_room_hosts=None, + ratelimit=True, + from_client=True, + ): + """ + Change the membership status of a user in a room. Args: - event (SynapseEvent): The membership event + event (SynapseEvent): The membership event. + context: The context of the event. + is_guest (bool): Whether the sender is a guest. + room_hosts ([str]): Homeservers which are likely to already be in + the room, and could be danced with in order to join this + homeserver for the first time. + ratelimit (bool): Whether to rate limit this request. + from_client (bool): Whether this request is the result of a local + client request (rather than over federation). If so, we will + perform extra checks, like that this homeserver can act as this + client. Raises: SynapseError if there was a problem changing the membership. """ - target_user_id = event.state_key + target_user = UserID.from_string(event.state_key) + room_id = event.room_id - prev_state = context.current_state.get( - (EventTypes.Member, target_user_id), - None - ) + if from_client: + sender = UserID.from_string(event.sender) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) - room_id = event.room_id + message_handler = self.hs.get_handlers().message_handler + prev_event = message_handler.deduplicate_state_event(event, context) + if prev_event is not None: + return - # If we're trying to join a room then we have to do this differently - # if this HS is not currently in the room, i.e. we have to do the - # invite/join dance. - if event.membership == Membership.JOIN: - if is_guest: - guest_access = context.current_state.get( - (EventTypes.GuestAccess, ""), - None - ) - is_guest_access_allowed = ( - guest_access - and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" - ) - if not is_guest_access_allowed: - raise AuthError(403, "Guest access not allowed") + action = "send" - yield self._do_join(event, context) - else: - if event.membership == Membership.LEAVE: - is_host_in_room = yield self.is_host_in_room(room_id, context) - if not is_host_in_room: - # Rejecting an invite, rather than leaving a joined room - handler = self.hs.get_handlers().federation_handler - inviter = yield self.get_inviter(event) - if not inviter: - # return the same error as join_room_alias does - raise SynapseError(404, "No known servers") - yield handler.do_remotely_reject_invite( - [inviter.domain], - room_id, - event.user_id - ) - defer.returnValue({"room_id": room_id}) - return - - # FIXME: This isn't idempotency. - if prev_state and prev_state.membership == event.membership: - # double same action, treat this event as a NOOP. - defer.returnValue({}) - return - - yield self._do_local_membership_update( - event, - context=context, + if event.membership == Membership.JOIN: + if is_guest and not self._can_guest_join(context.current_state): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + do_remote_join_dance, remote_room_hosts = self._should_do_dance( + context, + (self.get_inviter(event.state_key, context.current_state)), + remote_room_hosts, ) + if do_remote_join_dance: + action = "remote_join" + elif event.membership == Membership.LEAVE: + is_host_in_room = self.is_host_in_room(context.current_state) + if not is_host_in_room: + action = "remote_reject" - if prev_state and prev_state.membership == Membership.JOIN: - user = UserID.from_string(event.user_id) - user_left_room(self.distributor, user, event.room_id) + federation_handler = self.hs.get_handlers().federation_handler - defer.returnValue({"room_id": room_id}) + if action == "remote_join": + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") - @defer.inlineCallbacks - def join_room_alias(self, joinee, room_alias, content={}): - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield federation_handler.do_invite_join( + remote_room_hosts, + event.room_id, + event.user_id, + event.content, + ) + elif action == "remote_reject": + inviter = self.get_inviter(target_user.to_string(), context.current_state) + if not inviter: + raise SynapseError(404, "No known servers") + yield federation_handler.do_remotely_reject_invite( + [inviter.domain], + room_id, + event.user_id + ) + else: + yield self.handle_new_client_event( + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) - if not mapping: - raise SynapseError(404, "No such room alias") + prev_member_event = context.current_state.get( + (EventTypes.Member, target_user.to_string()), + None + ) - room_id = mapping["room_id"] - hosts = mapping["servers"] - if not hosts: - raise SynapseError(404, "No known servers") + if event.membership == Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + yield user_joined_room(self.distributor, target_user, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event and prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) + + def _can_guest_join(self, current_state): + """ + Returns whether a guest can join a room based on its current state. + """ + guest_access = current_state.get((EventTypes.GuestAccess, ""), None) + return ( + guest_access + and guest_access.content + and "guest_access" in guest_access.content + and guest_access.content["guest_access"] == "can_join" + ) - # If event doesn't include a display name, add one. - yield collect_presencelike_data(self.distributor, joinee, content) + def _should_do_dance(self, context, inviter, room_hosts=None): + # TODO: Shouldn't this be remote_room_host? + room_hosts = room_hosts or [] - content.update({"membership": Membership.JOIN}) - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "state_key": joinee.to_string(), - "room_id": room_id, - "sender": joinee.to_string(), - "membership": Membership.JOIN, - "content": content, - }) - event, context = yield self._create_new_client_event(builder) + is_host_in_room = self.is_host_in_room(context.current_state) + if is_host_in_room: + return False, room_hosts - yield self._do_join(event, context, room_hosts=hosts) + if inviter and not self.hs.is_mine(inviter): + room_hosts.append(inviter.domain) - defer.returnValue({"room_id": room_id}) + return True, room_hosts @defer.inlineCallbacks - def _do_join(self, event, context, room_hosts=None): - room_id = event.room_id - - # XXX: We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - - is_host_in_room = yield self.is_host_in_room(room_id, context) - if is_host_in_room: - should_do_dance = False - elif room_hosts: # TODO: Shouldn't this be remote_room_host? - should_do_dance = True - else: - inviter = yield self.get_inviter(event) - if not inviter: - # return the same error as join_room_alias does - raise SynapseError(404, "No known servers") - should_do_dance = not self.hs.is_mine(inviter) - room_hosts = [inviter.domain] + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. - if should_do_dance: - handler = self.hs.get_handlers().federation_handler - yield handler.do_invite_join( - room_hosts, - room_id, - event.user_id, - event.content, - ) - else: - logger.debug("Doing normal join") + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). + Raises: + SynapseError if room alias could not be found. + """ + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) - yield self._do_local_membership_update( - event, - context=context, - ) + if not mapping: + raise SynapseError(404, "No such room alias") - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. - user = UserID.from_string(event.user_id) - yield user_joined_room(self.distributor, user, room_id) + room_id = mapping["room_id"] + servers = mapping["servers"] - @defer.inlineCallbacks - def get_inviter(self, event): - # TODO(markjh): get prev_state from snapshot - prev_state = yield self.store.get_room_member( - event.user_id, event.room_id - ) + defer.returnValue((RoomID.from_string(room_id), servers)) + def get_inviter(self, user_id, current_state): + prev_state = current_state.get((EventTypes.Member, user_id)) if prev_state and prev_state.membership == Membership.INVITE: - defer.returnValue(UserID.from_string(prev_state.user_id)) - return - elif "third_party_invite" in event.content: - if "sender" in event.content["third_party_invite"]: - inviter = UserID.from_string( - event.content["third_party_invite"]["sender"] - ) - defer.returnValue(inviter) - defer.returnValue(None) - - @defer.inlineCallbacks - def is_host_in_room(self, room_id, context): - is_host_in_room = yield self.auth.check_host_in_room( - room_id, - self.hs.hostname - ) - if not is_host_in_room: - # is *anyone* in the room? - room_member_keys = [ - v for (k, v) in context.current_state.keys() if ( - k == "m.room.member" - ) - ] - if len(room_member_keys) == 0: - # has the room been created so we can join it? - create_event = context.current_state.get(("m.room.create", "")) - if create_event: - is_host_in_room = True - defer.returnValue(is_host_in_room) + return UserID.from_string(prev_state.user_id) + return None @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): @@ -657,18 +644,6 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(room_ids) @defer.inlineCallbacks - def _do_local_membership_update(self, event, context): - yield run_on_reactor() - - target_user = UserID.from_string(event.state_key) - - yield self.handle_new_client_event( - event, - context, - extra_users=[target_user], - ) - - @defer.inlineCallbacks def do_3pid_invite( self, room_id, @@ -676,7 +651,7 @@ class RoomMemberHandler(BaseHandler): medium, address, id_server, - token_id, + requester, txn_id ): invitee = yield self._lookup_3pid( @@ -684,19 +659,12 @@ class RoomMemberHandler(BaseHandler): ) if invitee: - # make sure it looks like a user ID; it'll throw if it's invalid. - UserID.from_string(invitee) - yield self.hs.get_handlers().message_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": { - "membership": unicode("invite") - }, - "room_id": room_id, - "sender": inviter.to_string(), - "state_key": invitee, - }, - token_id=token_id, + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", txn_id=txn_id, ) else: @@ -706,7 +674,7 @@ class RoomMemberHandler(BaseHandler): address, room_id, inviter, - token_id, + requester.access_token_id, txn_id=txn_id ) @@ -801,7 +769,7 @@ class RoomMemberHandler(BaseHandler): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_key, key_validity_url, display_name = ( + token, public_keys, fallback_public_key, display_name = ( yield self._ask_id_server_for_third_party_invite( id_server=id_server, medium=medium, @@ -816,14 +784,18 @@ class RoomMemberHandler(BaseHandler): inviter_avatar_url=inviter_avatar_url ) ) + msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event( + yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.ThirdPartyInvite, "content": { "display_name": display_name, - "key_validity_url": key_validity_url, - "public_key": public_key, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], }, "room_id": room_id, "sender": user.to_string(), @@ -848,6 +820,41 @@ class RoomMemberHandler(BaseHandler): inviter_display_name, inviter_avatar_url ): + """ + Asks an identity server for a third party invite. + + :param id_server (str): hostname + optional port for the identity server. + :param medium (str): The literal string "email". + :param address (str): The third party address being invited. + :param room_id (str): The ID of the room to which the user is invited. + :param inviter_user_id (str): The user ID of the inviter. + :param room_alias (str): An alias for the room, for cosmetic + notifications. + :param room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + :param room_join_rules (str): The join rules of the email + (e.g. "public"). + :param room_name (str): The m.room.name of the room. + :param inviter_display_name (str): The current display name of the + inviter. + :param inviter_avatar_url (str): The URL of the inviter's avatar. + + :return: A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( id_server_scheme, id_server, ) @@ -864,16 +871,26 @@ class RoomMemberHandler(BaseHandler): "sender": inviter_user_id, "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, + "guest_access_token": guest_access_token, } ) # TODO: Check for success token = data["token"] - public_key = data["public_key"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) display_name = data["display_name"] - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ) - defer.returnValue((token, public_key, key_validity_url, display_name)) + defer.returnValue((token, public_keys, fallback_public_key, display_name)) def forget(self, user, room_id): return self.store.forget(user.to_string(), room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1d0f0058a2..fded6e4009 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -121,7 +121,11 @@ class SyncResult(collections.namedtuple("SyncResult", [ events. """ return bool( - self.presence or self.joined or self.invited or self.archived + self.presence or + self.joined or + self.invited or + self.archived or + self.account_data ) @@ -582,6 +586,28 @@ class SyncHandler(BaseHandler): if room_sync: joined.append(room_sync) + # For each newly joined room, we want to send down presence of + # existing users. + presence_handler = self.hs.get_handlers().presence_handler + extra_presence_users = set() + for room_id in newly_joined_rooms: + users = yield self.store.get_users_in_room(event.room_id) + extra_presence_users.update(users) + + # For each new member, send down presence. + for joined_sync in joined: + it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + extra_presence_users.add(event.state_key) + + states = yield presence_handler.get_states( + [u for u in extra_presence_users if u != user_id], + as_event=True, + ) + presence.extend(states) + account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) ) @@ -623,7 +649,6 @@ class SyncHandler(BaseHandler): recents = yield self._filter_events_for_client( sync_config.user.to_string(), recents, - is_peeking=sync_config.is_guest, ) else: recents = [] @@ -645,7 +670,6 @@ class SyncHandler(BaseHandler): loaded_recents = yield self._filter_events_for_client( sync_config.user.to_string(), loaded_recents, - is_peeking=sync_config.is_guest, ) loaded_recents.extend(recents) recents = loaded_recents @@ -825,14 +849,20 @@ class SyncHandler(BaseHandler): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: + current_state = yield self.store.get_state_for_event( + batch.events[-1].event_id + ) + state = yield self.store.get_state_for_event( batch.events[0].event_id ) else: - state = yield self.get_state_at( + current_state = yield self.get_state_at( room_id, stream_position=now_token ) + state = current_state + timeline_state = { (event.type, event.state_key): event for event in batch.events if event.is_state() @@ -842,12 +872,17 @@ class SyncHandler(BaseHandler): timeline_contains=timeline_state, timeline_start=state, previous={}, + current=current_state, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) + current_state = yield self.store.get_state_for_event( + batch.events[-1].event_id + ) + state_at_timeline_start = yield self.store.get_state_for_event( batch.events[0].event_id ) @@ -861,6 +896,7 @@ class SyncHandler(BaseHandler): timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, + current=current_state, ) else: state = {} @@ -920,7 +956,7 @@ def _action_has_highlight(actions): return False -def _calculate_state(timeline_contains, timeline_start, previous): +def _calculate_state(timeline_contains, timeline_start, previous, current): """Works out what state to include in a sync response. Args: @@ -928,6 +964,7 @@ def _calculate_state(timeline_contains, timeline_start, previous): timeline_start (dict): state at the start of the timeline previous (dict): state at the end of the previous sync (or empty dict if this is an initial sync) + current (dict): state at the end of the timeline Returns: dict @@ -938,14 +975,16 @@ def _calculate_state(timeline_contains, timeline_start, previous): timeline_contains.values(), previous.values(), timeline_start.values(), + current.values(), ) } + c_ids = set(e.event_id for e in current.values()) tc_ids = set(e.event_id for e in timeline_contains.values()) p_ids = set(e.event_id for e in previous.values()) ts_ids = set(e.event_id for e in timeline_start.values()) - state_ids = (ts_ids - p_ids) - tc_ids + state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids evs = (event_id_to_state[e] for e in state_ids) return { diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b16d0017df..8ce27f49ec 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -25,6 +25,7 @@ from synapse.types import UserID import logging from collections import namedtuple +import ujson as json logger = logging.getLogger(__name__) @@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler): "typing_key", self._latest_room_serial, rooms=[room_id] ) + def get_all_typing_updates(self, last_id, current_id): + # TODO: Work out a way to do this without scanning the entire state. + rows = [] + for room_id, serial in self._room_serials.items(): + if last_id < serial and serial <= current_id: + typing = self._room_typing[room_id] + typing_bytes = json.dumps([ + u.to_string() for u in typing + ], ensure_ascii=False) + rows.append((serial, room_id, typing_bytes)) + rows.sort() + return rows + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/http/server.py b/synapse/http/server.py index a90e2e1125..b17b190ee5 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -367,10 +367,29 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, "Origin, X-Requested-With, Content-Type, Accept") request.write(json_bytes) - request.finish() + finish_request(request) return NOT_DONE_YET +def finish_request(request): + """ Finish writing the response to the request. + + Twisted throws a RuntimeException if the connection closed before the + response was written but doesn't provide a convenient or reliable way to + determine if the connection was closed. So we catch and log the RuntimeException + + You might think that ``request.notifyFinish`` could be used to tell if the + request was finished. However the deferred it returns won't fire if the + connection was already closed, meaning we'd have to have called the method + right at the start of the request. By the time we want to write the response + it will already be too late. + """ + try: + request.finish() + except RuntimeError as e: + logger.info("Connection disconnected before response was written: %r", e) + + def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( "User-Agent", default=[] diff --git a/synapse/notifier.py b/synapse/notifier.py index 560866b26e..3c36a20868 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,8 @@ class Notifier(object): self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS ) + self.replication_deferred = ObservableDeferred(defer.Deferred()) + # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. @@ -207,6 +209,8 @@ class Notifier(object): )) self._notify_pending_new_room_events(max_room_stream_id) + self.notify_replication() + def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -276,6 +280,8 @@ class Notifier(object): except: logger.exception("Failed to notify listener") + self.notify_replication() + @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): @@ -479,3 +485,45 @@ class Notifier(object): room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) + + def notify_replication(self): + """Notify the any replication listeners that there's a new event""" + with PreserveLoggingContext(): + deferred = self.replication_deferred + self.replication_deferred = ObservableDeferred(defer.Deferred()) + deferred.callback(None) + + @defer.inlineCallbacks + def wait_for_replication(self, callback, timeout): + """Wait for an event to happen. + + :param callback: + Gets called whenever an event happens. If this returns a truthy + value then ``wait_for_replication`` returns, otherwise it waits + for another event. + :param int timeout: + How many milliseconds to wait for callback return a truthy value. + :returns: + A deferred that resolves with the value returned by the callback. + """ + listener = _NotificationListener(None) + + def timed_out(): + listener.deferred.cancel() + + timer = self.clock.call_later(timeout / 1000., timed_out) + while True: + listener.deferred = self.replication_deferred.observe() + result = yield callback() + if result: + break + + try: + with PreserveLoggingContext(): + yield listener.deferred + except defer.CancelledError: + break + + self.clock.cancel_call_later(timer, ignore_errs=True) + + defer.returnValue(result) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 8da2d8716c..4c6c3b83a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -47,14 +47,13 @@ class Pusher(object): MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): self.hs = _hs self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.profile_tag = profile_tag self.user_id = user_id self.app_id = app_id self.app_display_name = app_display_name @@ -186,8 +185,8 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id_and_profile_tag( - self.user_id, self.profile_tag, single_event['room_id'], self.store + push_rule_evaluator.evaluator_for_user_id( + self.user_id, single_event['room_id'], self.store ) actions = yield rule_evaluator.actions_for_event(single_event) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index e0da0868ec..c6c1dc769e 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -44,5 +44,5 @@ class ActionGenerator: ) context.push_actions = [ - (uid, None, actions) for uid, actions in actions_by_user.items() + (uid, actions) for uid, actions in actions_by_user.items() ] diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 0832c77cb4..86a2998bcc 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -13,46 +13,67 @@ # limitations under the License. from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +import copy def list_with_base_rules(rawrules): + """Combine the list of rules set by the user with the default push rules + + :param list rawrules: The rules the user has modified or set. + :returns: A new list with the rules set by the user combined with the + defaults. + """ ruleslist = [] + # Grab the base rules that the user has modified. + # The modified base rules have a priority_class of -1. + modified_base_rules = { + r['rule_id']: r for r in rawrules if r['priority_class'] < 0 + } + + # Remove the modified base rules from the list, They'll be added back + # in the default postions in the list. + rawrules = [r for r in rawrules if r['priority_class'] >= 0] + # shove the server default rules for each kind onto the end of each current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules )) for r in rawrules: if r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) ruleslist.append(r) while current_prio_class > 0: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) return ruleslist -def make_base_append_rules(kind): +def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': @@ -62,15 +83,31 @@ def make_base_append_rules(kind): elif kind == 'content': rules = BASE_APPEND_CONTENT_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules -def make_base_prepend_rules(kind): +def make_base_prepend_rules(kind, modified_base_rules): rules = [] if kind == 'override': rules = BASE_PREPEND_OVERRIDE_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules @@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +BASE_RULE_IDS = set() + for r in BASE_APPEND_CONTENT_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_PREPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_OVRRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_UNDERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8ac5ceb9ef..5d8be483e5 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -103,7 +103,7 @@ class BulkPushRuleEvaluator: users_dict = yield self.store.are_guests(self.rules_by_user.keys()) - filtered_by_user = yield handler._filter_events_for_clients( + filtered_by_user = yield handler.filter_events_for_clients( users_dict.items(), [event], {event.event_id: current_state} ) @@ -152,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): elif res is True: continue - res = evaluator.matches(cond, uid, display_name, None) + res = evaluator.matches(cond, uid, display_name) if _id: cache[_id] = bool(res) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index cdc4494928..9be4869360 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -23,12 +23,11 @@ logger = logging.getLogger(__name__) class HttpPusher(Pusher): - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): super(HttpPusher, self).__init__( _hs, - profile_tag, user_id, app_id, app_display_name, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 2a2b4437dc..98e2a2015e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @defer.inlineCallbacks -def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): +def evaluator_for_user_id(user_id, room_id, store): rawrules = yield store.get_push_rules_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id) our_member_event = yield store.get_current_state( @@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): ) defer.returnValue(PushRuleEvaluator( - user_id, profile_tag, rawrules, enabled_map, + user_id, rawrules, enabled_map, room_id, our_member_event, store )) @@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count): class PushRuleEvaluator: DEFAULT_ACTIONS = [] - def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, + def __init__(self, user_id, raw_rules, enabled_map, room_id, our_member_event, store): self.user_id = user_id - self.profile_tag = profile_tag self.room_id = room_id self.our_member_event = our_member_event self.store = store @@ -152,7 +151,7 @@ class PushRuleEvaluator: matches = True for c in conditions: matches = evaluator.matches( - c, self.user_id, my_display_name, self.profile_tag + c, self.user_id, my_display_name ) if not matches: break @@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name, profile_tag): + def matches(self, condition, user_id, display_name): if condition['kind'] == 'event_match': return self._event_match(condition, user_id) - elif condition['kind'] == 'device': - if 'profile_tag' not in condition: - return True - return condition['profile_tag'] == profile_tag elif condition['kind'] == 'contains_display_name': return self._contains_display_name(display_name) elif condition['kind'] == 'room_member_count': diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d7dcb2de4b..a05aa5f661 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -29,6 +29,7 @@ class PusherPool: def __init__(self, _hs): self.hs = _hs self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() self.pushers = {} self.last_pusher_started = -1 @@ -38,8 +39,11 @@ class PusherPool: self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data): + def add_pusher(self, user_id, access_token, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data, + profile_tag=""): + time_now_msec = self.clock.time_msec() + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one @@ -47,23 +51,31 @@ class PusherPool: self._create_pusher({ "user_name": user_id, "kind": kind, - "profile_tag": profile_tag, "app_id": app_id, "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "ts": self.hs.get_clock().time_msec(), + "ts": time_now_msec, "lang": lang, "data": data, "last_token": None, "last_success": None, "failing_since": None }) - yield self._add_pusher_to_store( - user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, - pushkey, lang, data + yield self.store.add_pusher( + user_id=user_id, + access_token=access_token, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=time_now_msec, + lang=lang, + data=data, + profile_tag=profile_tag, ) + yield self._refresh_pusher(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -94,30 +106,10 @@ class PusherPool: ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - @defer.inlineCallbacks - def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, lang, data): - yield self.store.add_pusher( - user_id=user_id, - access_token=access_token, - profile_tag=profile_tag, - kind=kind, - app_id=app_id, - app_display_name=app_display_name, - device_display_name=device_display_name, - pushkey=pushkey, - pushkey_ts=self.hs.get_clock().time_msec(), - lang=lang, - data=data, - ) - yield self._refresh_pusher(app_id, pushkey, user_id) - def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': return HttpPusher( self.hs, - profile_tag=pusherdict['profile_tag'], user_id=pusherdict['user_name'], app_id=pusherdict['app_id'], app_display_name=pusherdict['app_display_name'], diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 75bf3d13aa..35933324a4 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "frozendict>=0.4": ["frozendict"], - "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], + "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], diff --git a/synapse/replication/__init__.py b/synapse/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py new file mode 100644 index 0000000000..e0d039518d --- /dev/null +++ b/synapse/replication/resource.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.servlet import parse_integer, parse_string +from synapse.http.server import request_handler, finish_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + +import ujson as json + +import collections +import logging + +logger = logging.getLogger(__name__) + +REPLICATION_PREFIX = "/_synapse/replication" + +STREAM_NAMES = ( + ("events",), + ("presence",), + ("typing",), + ("receipts",), + ("user_account_data", "room_account_data", "tag_account_data",), + ("backfill",), +) + + +class ReplicationResource(Resource): + """ + HTTP endpoint for extracting data from synapse. + + The streams of data returned by the endpoint are controlled by the + parameters given to the API. To return a given stream pass a query + parameter with a position in the stream to return data from or the + special value "-1" to return data from the start of the stream. + + If there is no data for any of the supplied streams after the given + position then the request will block until there is data for one + of the streams. This allows clients to long-poll this API. + + The possible streams are: + + * "streams": A special stream returing the positions of other streams. + * "events": The new events seen on the server. + * "presence": Presence updates. + * "typing": Typing updates. + * "receipts": Receipt updates. + * "user_account_data": Top-level per user account data. + * "room_account_data: Per room per user account data. + * "tag_account_data": Per room per user tags. + * "backfill": Old events that have been backfilled from other servers. + + The API takes two additional query parameters: + + * "timeout": How long to wait before returning an empty response. + * "limit": The maximum number of rows to return for the selected streams. + + The response is a JSON object with keys for each stream with updates. Under + each key is a JSON object with: + + * "postion": The current position of the stream. + * "field_names": The names of the fields in each row. + * "rows": The updates as an array of arrays. + + There are a number of ways this API could be used: + + 1) To replicate the contents of the backing database to another database. + 2) To be notified when the contents of a shared backing database changes. + 3) To "tail" the activity happening on a server for debugging. + + In the first case the client would track all of the streams and store it's + own copy of the data. + + In the second case the client might theoretically just be able to follow + the "streams" stream to track where the other streams are. However in + practise it will probably need to get the contents of the streams in + order to expire the any in-memory caches. Whether it gets the contents + of the streams from this replication API or directly from the backing + store is a matter of taste. + + In the third case the client would use the "streams" stream to find what + streams are available and their current positions. Then it can start + long-polling this replication API for new data on those streams. + """ + + isLeaf = True + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.store = hs.get_datastore() + self.sources = hs.get_event_sources() + self.presence_handler = hs.get_handlers().presence_handler + self.typing_handler = hs.get_handlers().typing_notification_handler + self.notifier = hs.notifier + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @defer.inlineCallbacks + def current_replication_token(self): + stream_token = yield self.sources.get_current_token() + backfill_token = yield self.store.get_current_backfill_token() + + defer.returnValue(_ReplicationToken( + stream_token.room_stream_id, + int(stream_token.presence_key), + int(stream_token.typing_key), + int(stream_token.receipt_key), + int(stream_token.account_data_key), + backfill_token, + )) + + @request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + limit = parse_integer(request, "limit", 100) + timeout = parse_integer(request, "timeout", 10 * 1000) + + request.setHeader(b"Content-Type", b"application/json") + writer = _Writer(request) + + @defer.inlineCallbacks + def replicate(): + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) + + yield self.account_data(writer, current_token, limit) + yield self.events(writer, current_token, limit) + yield self.presence(writer, current_token) # TODO: implement limit + yield self.typing(writer, current_token) # TODO: implement limit + yield self.receipts(writer, current_token, limit) + self.streams(writer, current_token) + + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.total) + + yield self.notifier.wait_for_replication(replicate, timeout) + + writer.finish() + + def streams(self, writer, current_token): + request_token = parse_string(writer.request, "streams") + + streams = [] + + if request_token is not None: + if request_token == "-1": + for names, position in zip(STREAM_NAMES, current_token): + streams.extend((name, position) for name in names) + else: + items = zip( + STREAM_NAMES, + current_token, + _ReplicationToken(request_token) + ) + for names, current_id, last_id in items: + if last_id < current_id: + streams.extend((name, current_id) for name in names) + + if streams: + writer.write_header_and_rows( + "streams", streams, ("name", "position"), + position=str(current_token) + ) + + @defer.inlineCallbacks + def events(self, writer, current_token, limit): + request_events = parse_integer(writer.request, "events") + request_backfill = parse_integer(writer.request, "backfill") + + if request_events is not None or request_backfill is not None: + if request_events is None: + request_events = current_token.events + if request_backfill is None: + request_backfill = current_token.backfill + events_rows, backfill_rows = yield self.store.get_all_new_events( + request_backfill, request_events, + current_token.backfill, current_token.events, + limit + ) + writer.write_header_and_rows( + "events", events_rows, ("position", "internal", "json") + ) + writer.write_header_and_rows( + "backfill", backfill_rows, ("position", "internal", "json") + ) + + @defer.inlineCallbacks + def presence(self, writer, current_token): + current_position = current_token.presence + + request_presence = parse_integer(writer.request, "presence") + + if request_presence is not None: + presence_rows = yield self.presence_handler.get_all_presence_updates( + request_presence, current_position + ) + writer.write_header_and_rows("presence", presence_rows, ( + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + )) + + @defer.inlineCallbacks + def typing(self, writer, current_token): + current_position = current_token.presence + + request_typing = parse_integer(writer.request, "typing") + + if request_typing is not None: + typing_rows = yield self.typing_handler.get_all_typing_updates( + request_typing, current_position + ) + writer.write_header_and_rows("typing", typing_rows, ( + "position", "room_id", "typing" + )) + + @defer.inlineCallbacks + def receipts(self, writer, current_token, limit): + current_position = current_token.receipts + + request_receipts = parse_integer(writer.request, "receipts") + + if request_receipts is not None: + receipts_rows = yield self.store.get_all_updated_receipts( + request_receipts, current_position, limit + ) + writer.write_header_and_rows("receipts", receipts_rows, ( + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + )) + + @defer.inlineCallbacks + def account_data(self, writer, current_token, limit): + current_position = current_token.account_data + + user_account_data = parse_integer(writer.request, "user_account_data") + room_account_data = parse_integer(writer.request, "room_account_data") + tag_account_data = parse_integer(writer.request, "tag_account_data") + + if user_account_data is not None or room_account_data is not None: + if user_account_data is None: + user_account_data = current_position + if room_account_data is None: + room_account_data = current_position + user_rows, room_rows = yield self.store.get_all_updated_account_data( + user_account_data, room_account_data, current_position, limit + ) + writer.write_header_and_rows("user_account_data", user_rows, ( + "position", "user_id", "type", "content" + )) + writer.write_header_and_rows("room_account_data", room_rows, ( + "position", "user_id", "room_id", "type", "content" + )) + + if tag_account_data is not None: + tag_rows = yield self.store.get_all_updated_tags( + tag_account_data, current_position, limit + ) + writer.write_header_and_rows("tag_account_data", tag_rows, ( + "position", "user_id", "room_id", "tags" + )) + + +class _Writer(object): + """Writes the streams as a JSON object as the response to the request""" + def __init__(self, request): + self.streams = {} + self.request = request + self.total = 0 + + def write_header_and_rows(self, name, rows, fields, position=None): + if not rows: + return + + if position is None: + position = rows[-1][0] + + self.streams[name] = { + "position": str(position), + "field_names": fields, + "rows": rows, + } + + self.total += len(rows) + + def finish(self): + self.request.write(json.dumps(self.streams, ensure_ascii=False)) + finish_request(self.request) + + +class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( + "events", "presence", "typing", "receipts", "account_data", "backfill", +))): + __slots__ = [] + + def __new__(cls, *args): + if len(args) == 1: + return cls(*(int(value) for value in args[0].split("_"))) + else: + return super(_ReplicationToken, cls).__new__(cls, *args) + + def __str__(self): + return "_".join(str(value) for value in self) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 7199113dac..f13272da8e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -17,6 +17,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID +from synapse.http.server import finish_request + from base import ClientV1RestServlet, client_path_patterns import simplejson as json @@ -263,7 +265,7 @@ class SAML2RestServlet(ClientV1RestServlet): '?status=authenticated&access_token=' + token + '&user_id=' + user_id + '&ava=' + urllib.quote(json.dumps(saml2_auth.ava))) - request.finish() + finish_request(request) defer.returnValue(None) defer.returnValue((200, {"status": "authenticated", "user_id": user_id, "token": token, @@ -272,7 +274,7 @@ class SAML2RestServlet(ClientV1RestServlet): request.redirect(urllib.unquote( request.args['RelayState'][0]) + '?status=not_authenticated') - request.finish() + finish_request(request) defer.returnValue(None) defer.returnValue((200, {"status": "not_authenticated"})) @@ -309,7 +311,7 @@ class CasRedirectServlet(ClientV1RestServlet): "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) }) request.redirect("%s?%s" % (self.cas_server_url, service_param)) - request.finish() + finish_request(request) class CasTicketServlet(ClientV1RestServlet): @@ -362,7 +364,7 @@ class CasTicketServlet(ClientV1RestServlet): redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) - request.finish() + finish_request(request) def add_login_token_to_redirect_url(self, url, token): url_parts = list(urlparse.urlparse(url)) @@ -402,10 +404,12 @@ def _parse_json(request): try: content = json.loads(request.content.read()) if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") + raise SynapseError( + 400, "Content must be a JSON object.", errcode=Codes.BAD_JSON + ) return content except ValueError: - raise SynapseError(400, "Content not JSON.") + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a6f8754e32..bbfa1d6ac4 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,7 +17,7 @@ """ from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError from synapse.types import UserID from .base import ClientV1RestServlet, client_path_patterns @@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - state = yield self.handlers.presence_handler.get_state( - target_user=user, auth_user=requester.user) + if requester.user != user: + allowed = yield self.handlers.presence_handler.is_visible( + observed_user=user, observer_user=requester.user, + ) + + if not allowed: + raise AuthError(403, "You are not allowed to see their presence.") + + state = yield self.handlers.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + if requester.user != user: + raise AuthError(403, "Can only set your own presence state") + state = {} try: content = json.loads(request.content.read()) @@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state( - target_user=user, auth_user=requester.user, state=state) + yield self.handlers.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet): raise SynapseError(400, "Cannot get another user's presence list") presence = yield self.handlers.presence_handler.get_presence_list( - observer_user=user, accepted=True) - - for p in presence: - observed_user = p.pop("observed_user") - p["user_id"] = observed_user.to_string() + observer_user=user, accepted=True + ) defer.returnValue((200, presence)) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 96633a176c..970a019223 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -import synapse.push.baserules as baserules +from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) @@ -55,12 +55,15 @@ class PushRuleRestServlet(ClientV1RestServlet): yield self.set_rule_attr(requester.user.to_string(), spec, content) defer.returnValue((200, {})) + if spec['rule_id'].startswith('.'): + # Rule ids starting with '.' are reserved for server default rules. + raise SynapseError(400, "cannot add new rule_ids that start with '.'") + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], spec['rule_id'], content, - device=spec['device'] if 'device' in spec else None ) except InvalidRuleException as e: raise SynapseError(400, e.message) @@ -129,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist)) + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) rules = {'global': {}, 'device': {}} @@ -153,23 +156,7 @@ class PushRuleRestServlet(ClientV1RestServlet): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - if r['priority_class'] > PRIORITY_CLASS_MAP['override']: - # per-device rule - profile_tag = _profile_tag_from_conditions(r["conditions"]) - r = _strip_device_condition(r) - if not profile_tag: - continue - if profile_tag not in rules['device']: - rules['device'][profile_tag] = {} - rules['device'][profile_tag] = ( - _add_empty_priority_class_arrays( - rules['device'][profile_tag] - ) - ) - - rulearray = rules['device'][profile_tag][template_name] - else: - rulearray = rules['global'][template_name] + rulearray = rules['global'][template_name] template_rule = _rule_to_template(r) if template_rule: @@ -195,24 +182,6 @@ class PushRuleRestServlet(ClientV1RestServlet): path = path[1:] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) - elif path[0] == 'device': - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if path[0] == '': - defer.returnValue((200, rules['device'])) - - profile_tag = path[0] - path = path[1:] - if profile_tag not in rules['device']: - ret = {} - ret = _add_empty_priority_class_arrays(ret) - defer.returnValue((200, ret)) - ruleset = rules['device'][profile_tag] - result = _filter_ruleset_with_path(ruleset, path) - defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -232,13 +201,17 @@ class PushRuleRestServlet(ClientV1RestServlet): return self.hs.get_datastore().set_push_rule_enabled( user_id, namespaced_rule_id, val ) - else: - raise UnrecognizedRequestError() - - def get_rule_attr(self, user_id, namespaced_rule_id, attr): - if attr == 'enabled': - return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( - user_id, namespaced_rule_id + elif spec['attr'] == 'actions': + actions = val.get('actions') + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec['rule_id'] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return self.hs.get_datastore().set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule ) else: raise UnrecognizedRequestError() @@ -252,16 +225,9 @@ def _rule_spec_from_path(path): scope = path[1] path = path[2:] - if scope not in ['global', 'device']: + if scope != 'global': raise UnrecognizedRequestError() - device = None - if scope == 'device': - if len(path) == 0: - raise UnrecognizedRequestError() - device = path[0] - path = path[1:] - if len(path) == 0: raise UnrecognizedRequestError() @@ -278,8 +244,6 @@ def _rule_spec_from_path(path): 'template': template, 'rule_id': rule_id } - if device: - spec['profile_tag'] = device path = path[1:] @@ -289,7 +253,7 @@ def _rule_spec_from_path(path): return spec -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): if rule_template in ['override', 'underride']: if 'conditions' not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -322,16 +286,19 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if device: - conditions.append({ - 'kind': 'device', - 'profile_tag': device - }) - if 'actions' not in req_obj: raise InvalidRuleException("No actions found") actions = req_obj['actions'] + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass @@ -340,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unrecognised action") - return conditions, actions - def _add_empty_priority_class_arrays(d): for pc in PRIORITY_CLASS_MAP.keys(): @@ -349,17 +314,6 @@ def _add_empty_priority_class_arrays(d): return d -def _profile_tag_from_conditions(conditions): - """ - Given a list of conditions, return the profile tag of the - device rule if there is one - """ - for c in conditions: - if c['kind'] == 'device': - return c['profile_tag'] - return None - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -393,29 +347,23 @@ def _filter_ruleset_with_path(ruleset, path): attr = path[0] if attr in the_rule: - return the_rule[attr] + # Make sure we return a JSON object as the attribute may be a + # JSON value. + return {attr: the_rule[attr]} else: raise UnrecognizedRequestError() def _priority_class_from_spec(spec): if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] - if spec['scope'] == 'device': - pc += len(PRIORITY_CLASS_MAP) - return pc def _priority_class_to_template_name(pc): - if pc > PRIORITY_CLASS_MAP['override']: - # per-device - prio_class_index = pc - len(PRIORITY_CLASS_MAP) - return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] - else: - return PRIORITY_CLASS_INVERSE_MAP[pc] + return PRIORITY_CLASS_INVERSE_MAP[pc] def _rule_to_template(rule): @@ -445,23 +393,12 @@ def _rule_to_template(rule): return templaterule -def _strip_device_condition(rule): - for i, c in enumerate(rule['conditions']): - if c['kind'] == 'device': - del rule['conditions'][i] - return rule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) def _namespaced_rule_id(spec, rule_id): - if spec['scope'] == 'global': - scope = 'global' - else: - scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], rule_id) + return "global/%s/%s" % (spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5547f1b112..4c662e6e3c 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -45,7 +45,7 @@ class PusherRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', + reqd = ['kind', 'app_id', 'app_display_name', 'device_display_name', 'pushkey', 'lang', 'data'] missing = [] for i in reqd: @@ -73,14 +73,14 @@ class PusherRestServlet(ClientV1RestServlet): yield pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], app_display_name=content['app_display_name'], device_display_name=content['device_display_name'], pushkey=content['pushkey'], lang=content['lang'], - data=content['data'] + data=content['data'], + profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: raise SynapseError(400, "Config Error: " + pce.message, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 81bfe377bd..f5ed4f7302 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -63,24 +63,12 @@ class RoomCreateRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - room_config = self.get_room_config(request) - info = yield self.make_room( - room_config, - requester.user, - None, - ) - room_config.update(info) - defer.returnValue((200, info)) - - @defer.inlineCallbacks - def make_room(self, room_config, auth_user, room_id): handler = self.handlers.room_creation_handler info = yield handler.create_room( - user_id=auth_user.to_string(), - room_id=room_id, - config=room_config + requester, self.get_room_config(request) ) - defer.returnValue(info) + + defer.returnValue((200, info)) def get_room_config(self, request): try: @@ -162,11 +150,22 @@ class RoomStateEventRestServlet(ClientV1RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event( - event_dict, token_id=requester.access_token_id, txn_id=txn_id, + event, context = yield msg_handler.create_event( + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id, ) - defer.returnValue((200, {})) + if event_type == EventTypes.Member: + yield self.handlers.room_member_handler.send_membership_event( + event, + context, + is_guest=requester.is_guest, + ) + else: + yield msg_handler.send_nonmember_event(event, context) + + defer.returnValue((200, {"event_id": event.event_id})) # TODO: Needs unit testing for generic events + feedback @@ -183,7 +182,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): content = _parse_json(request) msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_nonmember_event( { "type": event_type, "content": content, @@ -229,46 +228,37 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) - # the identifier could be a room alias or a room id. Try one then the - # other if it fails to parse, without swallowing other valid - # SynapseErrors. - - identifier = None - is_room_alias = False try: - identifier = RoomAlias.from_string(room_identifier) - is_room_alias = True - except SynapseError: - identifier = RoomID.from_string(room_identifier) - - # TODO: Support for specifying the home server to join with? - - if is_room_alias: + content = _parse_json(request) + except: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): handler = self.handlers.room_member_handler - ret_dict = yield handler.join_room_alias( - requester.user, - identifier, - ) - defer.returnValue((200, ret_dict)) - else: # room id - msg_handler = self.handlers.message_handler - content = {"membership": Membership.JOIN} - if requester.is_guest: - content["kind"] = "guest" - yield msg_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": identifier.to_string(), - "sender": requester.user.to_string(), - "state_key": requester.user.to_string(), - }, - token_id=requester.access_token_id, - txn_id=txn_id, - is_guest=requester.is_guest, - ) + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) - defer.returnValue((200, {"room_id": identifier.to_string()})) + yield self.handlers.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=content.get("third_party_signed", None), + ) + + defer.returnValue((200, {"room_id": room_id})) @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): @@ -316,18 +306,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet): if event["type"] != EventTypes.Member: continue chunk.append(event) - # FIXME: should probably be state_key here, not user_id - target_user = UserID.from_string(event["user_id"]) - # Presence is an optional cache; don't fail if we can't fetch it - try: - presence_handler = self.handlers.presence_handler - presence_state = yield presence_handler.get_state( - target_user=target_user, - auth_user=requester.user, - ) - event["content"].update(presence_state) - except: - pass defer.returnValue((200, { "chunk": chunk @@ -454,7 +432,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): }: raise AuthError(403, "Guest access not allowed") - content = _parse_json(request) + try: + content = _parse_json(request) + except: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): yield self.handlers.room_member_handler.do_3pid_invite( @@ -463,7 +446,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content["medium"], content["address"], content["id_server"], - requester.access_token_id, + requester, txn_id ) defer.returnValue((200, {})) @@ -481,6 +464,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): room_id=room_id, action=membership_action, txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), ) defer.returnValue((200, {})) @@ -519,7 +503,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): content = _parse_json(request) msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.Redaction, "content": content, @@ -553,6 +537,10 @@ class RoomTypingRestServlet(ClientV1RestServlet): "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" ) + def __init__(self, hs): + super(RoomTypingRestServlet, self).__init__(hs) + self.presence_handler = hs.get_handlers().presence_handler + @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) @@ -564,6 +552,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): typing_handler = self.handlers.typing_notification_handler + yield self.presence_handler.bump_presence_active_time(requester.user) + if content["typing"]: yield typing_handler.started_typing( target_user=target_user, diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index ff71c40b43..78181b7b18 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.http.server import finish_request from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -130,7 +131,7 @@ class AuthRestServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) - request.finish() + finish_request(request) defer.returnValue(None) else: raise SynapseError(404, "Unknown auth stage type") @@ -176,7 +177,7 @@ class AuthRestServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) - request.finish() + finish_request(request) defer.returnValue(None) else: diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index eb4b369a3d..b831d8c95e 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_handlers().receipts_handler + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): @@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.presence_handler.bump_presence_active_time(requester.user) + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index accbc6cfac..de4a020ad4 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.events.utils import ( ) from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError +from synapse.api.constants import PresenceState from ._base import client_v2_patterns import copy @@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet): self.sync_handler = hs.get_handlers().sync_handler self.clock = hs.get_clock() self.filtering = hs.get_filtering() + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_GET(self, request): @@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet): else: since_token = None - if set_presence == "online": - yield self.event_stream_handler.started_stream(user) + affect_presence = set_presence != PresenceState.OFFLINE - try: + if affect_presence: + yield self.presence_handler.set_state(user, {"presence": set_presence}) + + context = yield self.presence_handler.user_syncing( + user.to_string(), affect_presence=affect_presence, + ) + with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, full_state=full_state ) - finally: - if set_presence == "online": - self.event_stream_handler.stopped_stream(user) time_now = self.clock.time_msec() diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index dcf3eaee1f..d9fc045fc6 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json_bytes +from synapse.http.server import respond_with_json_bytes, finish_request from synapse.util.stringutils import random_string from synapse.api.errors import ( @@ -144,7 +144,7 @@ class ContentRepoResource(resource.Resource): # after the file has been sent, clean up and finish the request def cbFinished(ignored): f.close() - request.finish() + finish_request(request) d.addCallback(cbFinished) else: respond_with_json_bytes( diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 58d56ec7a4..58ef91c0b8 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -16,7 +16,7 @@ from .thumbnailer import Thumbnailer from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.http.server import respond_with_json +from synapse.http.server import respond_with_json, finish_request from synapse.util.stringutils import random_string from synapse.api.errors import ( cs_error, Codes, SynapseError @@ -238,7 +238,7 @@ class BaseMediaResource(Resource): with open(file_path, "rb") as f: yield FileSender().beginFileTransfer(f, request) - request.finish() + finish_request(request) else: self._respond_404(request) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 5a9e7720d9..f257721ea3 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,7 @@ from .appservice import ( from ._base import Cache from .directory import DirectoryStore from .events import EventsStore -from .presence import PresenceStore +from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore from .registration import RegistrationStore from .room import RoomStore @@ -47,6 +47,7 @@ from .account_data import AccountDataStore from util.id_generators import IdGenerator, StreamIdGenerator +from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -110,16 +111,19 @@ class DataStore(RoomMemberStore, RoomStore, self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) - self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) - self._state_groups_id_gen = IdGenerator("state_groups", "id", self) - self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) - self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) - self._pushers_id_gen = IdGenerator("pushers", "id", self) - self._push_rule_id_gen = IdGenerator("push_rules", "id", self) - self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - events_max = self._stream_id_gen.get_max_token(None) + events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token(None) + account_max = self._account_data_id_gen.get_max_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) + self.__presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_max_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", min_presence_val, + prefilled_cache=presence_cache_prefill + ) + super(DataStore, self).__init__(hs) + def take_presence_startup_info(self): + active_on_startup = self.__presence_on_startup + self.__presence_on_startup = None + return active_on_startup + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will @@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() + txn.close() cache = { row[0]: int(row[1]) @@ -174,6 +197,28 @@ class DataStore(RoomMemberStore, RoomStore, return cache, min_val + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index b8387fc500..faddefe219 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): - """Get all the client account_data for a that's changed. + def get_all_updated_account_data(self, last_global_id, last_room_id, + current_id, limit): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return (global_results, room_results) + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. @@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_room_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_user_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index ce2c794025..3489315e0d 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) - def get_latest_events_in_room(self, room_id): + def get_latest_event_ids_and_hashes_in_room(self, room_id): return self.runInteraction( - "get_latest_events_in_room", - self._get_latest_events_in_room, + "get_latest_event_ids_and_hashes_in_room", + self._get_latest_event_ids_and_hashes_in_room, room_id, ) @@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore): desc="get_latest_event_ids_in_room", ) - def _get_latest_events_in_room(self, txn, room_id): + def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): sql = ( "SELECT e.event_id, e.depth FROM events as e " "INNER JOIN event_forward_extremities as f " diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d77a817682..5820539a92 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for - :param tuples: list of tuples of (user_id, profile_tag, actions) + :param tuples: list of tuples of (user_id, actions) """ values = [] - for uid, profile_tag, actions in tuples: + for uid, actions in tuples: values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'profile_tag': profile_tag, 'actions': json.dumps(actions), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, @@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - for uid, _, __ in tuples: + for uid, __ in tuples: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 3a5c6ee4b1..60936500d8 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore): yield stream_orderings stream_ordering_manager = stream_ordering_manager() else: - stream_ordering_manager = yield self._stream_id_gen.get_next_mult( - self, len(events_and_contexts) + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) ) with stream_ordering_manager as stream_orderings: @@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore): stream_ordering = self.min_stream_token if stream_ordering is None: - stream_ordering_manager = yield self._stream_id_gen.get_next(self) + stream_ordering_manager = self._stream_id_gen.get_next() else: @contextmanager def stream_ordering_manager(): @@ -131,7 +131,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = yield self._stream_id_gen.get_max_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore): yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) defer.returnValue(result) + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + + # TODO: Fix race with the persit_event txn by using one of the + # stream id managers + return -self.min_stream_token + + def get_all_new_events(self, last_backfill_id, last_forward_id, + current_backfill_id, current_forward_id, limit): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + if last_forward_id != current_forward_id: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + else: + new_forward_events = [] + + sql = ( + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" + " ORDER BY e.stream_ordering DESC" + " LIMIT ?" + ) + if last_backfill_id != current_backfill_id: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + else: + new_backfill_events = [] + + return (new_forward_events, new_backfill_events) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..0fd5d497ab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index ef525f34c5..4cec31e316 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,73 +14,148 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.api.constants import PresenceState +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from collections import namedtuple from twisted.internet import defer +class UserPresenceState(namedtuple("UserPresenceState", + ("user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active"))): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. + """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ) + + class PresenceStore(SQLBaseStore): - def create_presence(self, user_localpart): - res = self._simple_insert( - table="presence", - values={"user_id": user_localpart}, - desc="create_presence", + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) ) - self.get_presence_state.invalidate((user_localpart,)) - return res + with stream_ordering_manager as stream_orderings: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, stream_orderings, presence_states, + ) - def has_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["user_id"], - allow_none=True, - desc="has_presence_state", + defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, + state.user_id, stream_id, + ) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], ) - @cached(max_entries=2000) - def get_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + # Delete old rows to stop database from getting really big + sql = ( + "DELETE FROM presence_stream WHERE" + " stream_id < ?" + " AND user_id IN (%s)" ) - @cachedList(get_presence_state.cache, list_name="user_localparts", - inlineCallbacks=True) - def get_presence_states(self, user_localparts): - rows = yield self._simple_select_many_batch( - table="presence", - column="user_id", - iterable=user_localparts, - retcols=("user_id", "state", "status_msg", "mtime",), - desc="get_presence_states", + batches = ( + presence_states[i:i + 50] + for i in xrange(0, len(presence_states), 50) ) + for states in batches: + args = [stream_id] + args.extend(s.user_id for s in states) + txn.execute( + sql % (",".join("?" for _ in states),), + args + ) + + def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() - defer.returnValue({ - row["user_id"]: { - "state": row["state"], - "status_msg": row["status_msg"], - "mtime": row["mtime"], - } - for row in rows - }) + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) @defer.inlineCallbacks - def set_presence_state(self, user_localpart, new_state): - res = yield self._simple_update_one( - table="presence", - keyvalues={"user_id": user_localpart}, - updatevalues={"state": new_state["state"], - "status_msg": new_state["status_msg"], - "mtime": self._clock.time_msec()}, - desc="set_presence_state", + def get_presence_for_users(self, user_ids): + rows = yield self._simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), ) - self.get_presence_state.invalidate((user_localpart,)) - defer.returnValue(res) + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + defer.returnValue([UserPresenceState(**row) for row in rows]) + + def get_current_presence_token(self): + return self._presence_id_gen.get_max_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -128,6 +203,7 @@ class PresenceStore(SQLBaseStore): desc="set_presence_list_accepted", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -154,6 +230,19 @@ class PresenceStore(SQLBaseStore): desc="get_presence_list_accepted", ) + @cachedInlineCallbacks() + def get_presence_list_observers_accepted(self, observed_userid): + user_localparts = yield self._simple_select_onecol( + table="presence_list", + keyvalues={"observed_user_id": observed_userid, "accepted": True}, + retcol="user_id", + desc="get_presence_list_accepted", + ) + + defer.returnValue([ + "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts + ]) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): yield self._simple_delete_one( @@ -163,3 +252,4 @@ class PresenceStore(SQLBaseStore): desc="del_presence_list", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index f9a48171ba..56e69495b1 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -99,38 +99,36 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] defer.returnValue(results) - @defer.inlineCallbacks - def add_push_rule(self, before, after, **kwargs): - vals = kwargs - if 'conditions' in vals: - vals['conditions'] = json.dumps(vals['conditions']) - if 'actions' in vals: - vals['actions'] = json.dumps(vals['actions']) - - # we could check the rest of the keys are valid column names - # but sqlite will do that anyway so I think it's just pointless. - vals.pop("id", None) + def add_push_rule( + self, user_id, rule_id, priority_class, conditions, actions, + before=None, after=None + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) if before or after: - ret = yield self.runInteraction( + return self.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, - before=before, - after=after, - **vals + user_id, rule_id, priority_class, + conditions_json, actions_json, before, after, ) - defer.returnValue(ret) else: - ret = yield self.runInteraction( + return self.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, - **vals + user_id, rule_id, priority_class, + conditions_json, actions_json, ) - defer.returnValue(ret) - def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): - after = kwargs.pop("after", None) - before = kwargs.pop("before", None) + def _add_push_rule_relative_txn( + self, txn, user_id, rule_id, priority_class, + conditions_json, actions_json, before, after + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + relative_to_rule = before or after res = self._simple_select_one_txn( @@ -149,69 +147,45 @@ class PushRuleStore(SQLBaseStore): "before/after rule not found: %s" % (relative_to_rule,) ) - priority_class = res["priority_class"] + base_priority_class = res["priority_class"] base_rule_priority = res["priority"] - if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + if base_priority_class != priority_class: raise InconsistentRuleException( "Given priority class does not match class of relative rule" ) - new_rule = kwargs - new_rule.pop("before", None) - new_rule.pop("after", None) - new_rule['priority_class'] = priority_class - new_rule['user_name'] = user_id - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - - # check if the priority before/after is free - new_rule_priority = base_rule_priority - if after: - new_rule_priority -= 1 + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 else: - new_rule_priority += 1 - - new_rule['priority'] = new_rule_priority + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority sql = ( - "SELECT COUNT(*) FROM push_rules" - " WHERE user_name = ? AND priority_class = ? AND priority = ?" + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" ) - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - res = txn.fetchall() - num_conflicting = res[0][0] - - # if there are conflicting rules, bump everything - if num_conflicting: - sql = "UPDATE push_rules SET priority = priority " - if after: - sql += "-1" - else: - sql += "+1" - sql += " WHERE user_name = ? AND priority_class = ? AND priority " - if after: - sql += "<= ?" - else: - sql += ">= ?" - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) + txn.execute(sql, (user_id, priority_class, new_rule_priority)) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, new_rule_priority, + conditions_json, actions_json, ) - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) + def _add_push_rule_highest_priority_txn( + self, txn, user_id, rule_id, priority_class, + conditions_json, actions_json + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") - def _add_push_rule_highest_priority_txn(self, txn, user_id, - priority_class, **kwargs): # find the highest priority rule in that class sql = ( "SELECT COUNT(*), MAX(priority) FROM push_rules" @@ -225,12 +199,48 @@ class PushRuleStore(SQLBaseStore): if how_many > 0: new_prio = highest_prio + 1 - # and insert the new rule - new_rule = kwargs - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - new_rule['user_name'] = user_id - new_rule['priority_class'] = priority_class - new_rule['priority'] = new_prio + self._upsert_push_rule_txn( + txn, + user_id, rule_id, priority_class, new_prio, + conditions_json, actions_json, + ) + + def _upsert_push_rule_txn( + self, txn, user_id, rule_id, priority_class, + priority, conditions_json, actions_json + ): + """Specialised version of _simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" + ) + + txn.execute(sql, ( + priority_class, priority, conditions_json, actions_json, + user_id, rule_id, + )) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self._simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) txn.call_after( self.get_push_rules_for_user.invalidate, (user_id,) @@ -239,12 +249,6 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -275,7 +279,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(ret) def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): - new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( txn, "push_rules_enable", @@ -290,6 +294,31 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, priority, + "[]", actions_json + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + {'actions': actions_json}, + ) + + return self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + ) + class RuleNotFoundException(Exception): pass diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8ec706178a..7693ab9082 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -80,11 +80,11 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, + def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data): + pushkey, pushkey_ts, lang, data, profile_tag=""): try: - next_id = yield self._pushers_id_gen.get_next() + next_id = self._pushers_id_gen.get_next() yield self._simple_upsert( "pushers", dict( @@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore): dict( access_token=access_token, kind=kind, - profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, data=encode_canonical_json(data), + profile_tag=profile_tag, ), insertion_values=dict( id=next_id, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4202a6b3dc..dbc074d6b5 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() ) @cached(num_args=2) @@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token(self) + return self._receipts_id_gen.get_max_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = yield self._receipts_id_gen.get_next(self) + stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: have_persisted = yield self.runInteraction( "insert_linearized_receipt", @@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = self._stream_id_gen.get_max_token() defer.returnValue((stream_id, max_persisted_id)) @@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) + + def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 967c732bda..ad1157f979 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._access_tokens_id_gen.get_next() + next_id = self._access_tokens_id_gen.get_next() yield self._simple_insert( "access_tokens", @@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._refresh_tokens_id_gen.get_next() + next_id = self._refresh_tokens_id_gen.get_next() yield self._simple_insert( "refresh_tokens", @@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore): def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next_txn(txn) + next_id = self._access_tokens_id_gen.get_next() try: if was_guest: @@ -387,3 +387,47 @@ class RegistrationStore(SQLBaseStore): "find_next_generated_user_id", _find_next_generated_user_id ))) + + @defer.inlineCallbacks + def get_3pid_guest_access_token(self, medium, address): + ret = yield self._simple_select_one( + "threepid_guest_access_tokens", + { + "medium": medium, + "address": address + }, + ["guest_access_token"], True, 'get_3pid_guest_access_token' + ) + if ret: + defer.returnValue(ret["guest_access_token"]) + defer.returnValue(None) + + @defer.inlineCallbacks + def save_or_get_3pid_guest_access_token( + self, medium, address, access_token, inviter_user_id + ): + """ + Gets the 3pid's guest access token if exists, else saves access_token. + + :param medium (str): Medium of the 3pid. Must be "email". + :param address (str): 3pid address. + :param access_token (str): The access token to persist if none is + already persisted. + :param inviter_user_id (str): User ID of the inviter. + :return (deferred str): Whichever access token is persisted at the end + of this function call. + """ + def insert(txn): + txn.execute( + "INSERT INTO threepid_guest_access_tokens " + "(medium, address, guest_access_token, first_inviter) " + "VALUES (?, ?, ?, ?)", + (medium, address, access_token, inviter_user_id) + ) + + try: + yield self.runInteraction("save_3pid_guest_access_token", insert) + defer.returnValue(access_token) + except self.database_engine.module.IntegrityError: + ret = yield self.get_3pid_guest_access_token(medium, address) + defer.returnValue(ret) diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..606bbb037d --- /dev/null +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active_ts BIGINT, + last_federation_update_ts BIGINT, + last_user_sync_ts BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql new file mode 100644 index 0000000000..0dd2f1360c --- /dev/null +++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores guest account access tokens generated for unbound 3pids. +CREATE TABLE threepid_guest_access_tokens( + medium TEXT, -- The medium of the 3pid. Must be "email". + address TEXT, -- The 3pid address. + guest_access_token TEXT, -- The access token for a guest user for this 3pid. + first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. +); + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 372b540002..8ed8a21b0a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,7 +83,7 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next_txn(txn) + state_group = self._state_groups_id_gen.get_next() self._simple_insert_txn( txn, table="state_groups", diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c236dafafb..8908d5b5da 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -531,7 +531,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token(self) + token = yield self._stream_id_gen.get_max_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index e1a9c0c261..a0e6b42b30 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token(self) + return self._account_data_id_gen.get_max_token() @cached() def get_tags_for_user(self, user_id): @@ -59,6 +59,59 @@ class TagsStore(SQLBaseStore): return deferred @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" + " FROM room_tags" + " WHERE user_id=? AND room_id=?" + ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn.fetchall(): + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in xrange(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i:i + batch_size], + ) + results.extend(tags) + + defer.returnValue(results) + + @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): """Get all the tags for the rooms where the tags have changed since the given version @@ -142,12 +195,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -164,12 +217,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 4475c451c1..d338dfcf0a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): - next_id = self._transaction_id_gen.get_next_txn(txn) + next_id = self._transaction_id_gen.get_next() # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4ab9..efe3f68e6e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,51 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from collections import deque import contextlib import threading class IdGenerator(object): - def __init__(self, table, column, store): + def __init__(self, db_conn, table, column): self.table = table self.column = column - self.store = store self._lock = threading.Lock() - self._next_id = None + cur = db_conn.cursor() + self._next_id = self._load_next_id(cur) + cur.close() - @defer.inlineCallbacks - def get_next(self): - if self._next_id is None: - yield self.store.runInteraction( - "IdGenerator_%s" % (self.table,), - self.get_next_txn, - ) + def _load_next_id(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) + val, = txn.fetchone() + return val + 1 if val else 1 + def get_next(self): with self._lock: i = self._next_id self._next_id += 1 - defer.returnValue(i) - - def get_next_txn(self, txn): - with self._lock: - if self._next_id: - i = self._next_id - self._next_id += 1 - return i - else: - txn.execute( - "SELECT MAX(%s) FROM %s" % (self.column, self.table,) - ) - - val, = txn.fetchone() - cur = val or 0 - cur += 1 - self._next_id = cur + 1 - - return cur + return i class StreamIdGenerator(object): @@ -69,7 +48,7 @@ class StreamIdGenerator(object): persistence of events can complete out of order. Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ def __init__(self, db_conn, table, column): @@ -79,15 +58,21 @@ class StreamIdGenerator(object): self._lock = threading.Lock() cur = db_conn.cursor() - self._current_max = self._get_or_compute_current_max(cur) + self._current_max = self._load_current_max(cur) cur.close() self._unfinished_ids = deque() - def get_next(self, store): + def _load_current_max(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) + rows = txn.fetchall() + val, = rows[0] + return int(val) if val else 1 + + def get_next(self): """ Usage: - with yield stream_id_gen.get_next as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -106,10 +91,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, store, n): + def get_next_mult(self, n): """ Usage: - with yield stream_id_gen.get_next(store, n) as stream_ids: + with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -130,7 +115,7 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -139,13 +124,3 @@ class StreamIdGenerator(object): return self._unfinished_ids[0] - 1 return self._current_max - - def _get_or_compute_current_max(self, txn): - with self._lock: - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] - - self._current_max = int(val) if val else 1 - - return self._current_max diff --git a/synapse/types.py b/synapse/types.py index 2095837ba6..d5bd95cbd3 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -73,6 +73,14 @@ class DomainSpecificString( """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + @classmethod + def is_valid(cls, s): + try: + cls.from_string(s) + return True + except: + return False + __str__ = to_string @classmethod diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 133671e238..3b9da5b34a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -42,7 +42,7 @@ class Clock(object): def time_msec(self): """Returns the current system time in miliseconds since epoch.""" - return self.time() * 1000 + return int(self.time() * 1000) def looping_call(self, f, msec): l = task.LoopingCall(f) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 277854ccbc..35544b19fd 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -28,6 +28,7 @@ from twisted.internet import defer from collections import OrderedDict +import os import functools import inspect import threading @@ -38,6 +39,9 @@ logger = logging.getLogger(__name__) _CacheSentinel = object() +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): @@ -140,6 +144,8 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, inlineCallbacks=False): + max_entries = int(max_entries * CACHE_SIZE_FACTOR) + self.orig = orig if inlineCallbacks: diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 62cae99649..e863a8f8a9 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.caches import cache_counter, caches_by_name + import logging @@ -47,6 +49,8 @@ class ExpiringCache(object): self._cache = {} + caches_by_name[cache_name] = self._cache + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire @@ -72,7 +76,12 @@ class ExpiringCache(object): self._cache.pop(k) def __getitem__(self, key): - entry = self._cache[key] + try: + entry = self._cache[key] + cache_counter.inc_hits(self._cache_name) + except KeyError: + cache_counter.inc_misses(self._cache_name) + raise if self._reset_expiry_on_get: entry.time = self._clock.time_msec() @@ -105,9 +114,12 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache.keys()) + self._cache_name, begin_length, len(self._cache) ) + def __len__(self): + return len(self._cache) + class _CacheEntry(object): def __init__(self, time, value): diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index b37f1c0725..a1aec7aa55 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -18,11 +18,15 @@ from synapse.util.caches import cache_counter, caches_by_name from blist import sorteddict import logging +import os logger = logging.getLogger(__name__) +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + class StreamChangeCache(object): """Keeps track of the stream positions of the latest change in a set of entities. @@ -33,7 +37,7 @@ class StreamChangeCache(object): old then the cache will simply return all given entities. """ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}): - self._max_size = max_size + self._max_size = int(max_size * CACHE_SIZE_FACTOR) self._entity_to_key = {} self._cache = sorteddict() self._earliest_known_stream_pos = current_stream_pos @@ -85,6 +89,22 @@ class StreamChangeCache(object): return result + def get_all_entities_changed(self, stream_pos): + """Returns all entites that have had new things since the given + position. If the position is too old it will return None. + """ + assert type(stream_pos) is int + + if stream_pos >= self._earliest_known_stream_pos: + keys = self._cache.keys() + i = keys.bisect_right(stream_pos) + + return ( + self._cache[k] for k in keys[i:] + ) + else: + return None + def entity_has_changed(self, entity, stream_pos): """Informs the cache that the entity has been changed at the given position. diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py new file mode 100644 index 0000000000..7412fc57a4 --- /dev/null +++ b/synapse/util/wheel_timer.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class _Entry(object): + __slots__ = ["end_key", "queue"] + + def __init__(self, end_key): + self.end_key = end_key + self.queue = [] + + +class WheelTimer(object): + """Stores arbitrary objects that will be returned after their timers have + expired. + """ + + def __init__(self, bucket_size=5000): + """ + Args: + bucket_size (int): Size of buckets in ms. Corresponds roughly to the + accuracy of the timer. + """ + self.bucket_size = bucket_size + self.entries = [] + self.current_tick = 0 + + def insert(self, now, obj, then): + """Inserts object into timer. + + Args: + now (int): Current time in msec + obj (object): Object to be inserted + then (int): When to return the object strictly after. + """ + then_key = int(then / self.bucket_size) + 1 + + if self.entries: + min_key = self.entries[0].end_key + max_key = self.entries[-1].end_key + + if then_key <= max_key: + # The max here is to protect against inserts for times in the past + self.entries[max(min_key, then_key) - min_key].queue.append(obj) + return + + next_key = int(now / self.bucket_size) + 1 + if self.entries: + last_key = self.entries[-1].end_key + else: + last_key = next_key + + # Handle the case when `then` is in the past and `entries` is empty. + then_key = max(last_key, then_key) + + # Add empty entries between the end of the current list and when we want + # to insert. This ensures there are no gaps. + self.entries.extend( + _Entry(key) for key in xrange(last_key, then_key + 1) + ) + + self.entries[-1].queue.append(obj) + + def fetch(self, now): + """Fetch any objects that have timed out + + Args: + now (ms): Current time in msec + + Returns: + list: List of objects that have timed out + """ + now_key = int(now / self.bucket_size) + + ret = [] + while self.entries and self.entries[0].end_key <= now_key: + ret.extend(self.entries.pop(0).queue) + + return ret + + def __len__(self): + l = 0 + for entry in self.entries: + l += len(entry.queue) + return l |