diff options
116 files changed, 3983 insertions, 4302 deletions
diff --git a/README.rst b/README.rst index 39a338c790..8a745259bf 100644 --- a/README.rst +++ b/README.rst @@ -565,4 +565,21 @@ sphinxcontrib-napoleon:: Building internal API documentation:: python setup.py build_sphinx - \ No newline at end of file + + + +Halp!! Synapse eats all my RAM! +=============================== + +Synapse's architecture is quite RAM hungry currently - we deliberately +cache a lot of recent room data and metadata in RAM in order to speed up +common requests. We'll improve this in future, but for now the easiest +way to either reduce the RAM usage (at the risk of slowing things down) +is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment +variable. Roughly speaking, a SYNAPSE_CACHE_FACTOR of 1.0 will max out +at around 3-4GB of resident memory - this is what we currently run the +matrix.org on. The default setting is currently 0.1, which is probably +around a ~700MB footprint. You can dial it down further to 0.02 if +desired, which targets roughly ~512MB. Conversely you can dial it up if +you need performance for lots of users and have a box with a lot of RAM. + diff --git a/jenkins.sh b/jenkins.sh index ed3bdf80fa..9646ac0b01 100755 --- a/jenkins.sh +++ b/jenkins.sh @@ -1,6 +1,9 @@ #!/bin/bash -eu +: ${WORKSPACE:="$(pwd)"} + export PYTHONDONTWRITEBYTECODE=yep +export SYNAPSE_CACHE_FACTOR=1 # Output test results as junit xml export TRIAL_FLAGS="--reporter=subunit" diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py new file mode 100644 index 0000000000..18be711e92 --- /dev/null +++ b/scripts-dev/tail-synapse.py @@ -0,0 +1,67 @@ +import requests +import collections +import sys +import time +import json + +Entry = collections.namedtuple("Entry", "name position rows") + +ROW_TYPES = {} + + +def row_type_for_columns(name, column_names): + column_names = tuple(column_names) + row_type = ROW_TYPES.get((name, column_names)) + if row_type is None: + row_type = collections.namedtuple(name, column_names) + ROW_TYPES[(name, column_names)] = row_type + return row_type + + +def parse_response(content): + streams = json.loads(content) + result = {} + for name, value in streams.items(): + row_type = row_type_for_columns(name, value["field_names"]) + position = value["position"] + rows = [row_type(*row) for row in value["rows"]] + result[name] = Entry(name, position, rows) + return result + + +def replicate(server, streams): + return parse_response(requests.get( + server + "/_synapse/replication", + verify=False, + params=streams + ).content) + + +def main(): + server = sys.argv[1] + + streams = None + while not streams: + try: + streams = { + row.name: row.position + for row in replicate(server, {"streams":"-1"})["streams"].rows + } + except requests.exceptions.ConnectionError as e: + time.sleep(0.1) + + while True: + try: + results = replicate(server, streams) + except: + sys.stdout.write("connection_lost("+ repr(streams) + ")\n") + break + for update in results.values(): + for row in update.rows: + sys.stdout.write(repr(row) + "\n") + streams[update.name] = update.position + + + +if __name__=='__main__': + main() diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e2f84c4d57..183245443c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -434,31 +434,46 @@ class Auth(object): if event.user_id != invite_event.user_id: return False - try: - public_key = invite_event.content["public_key"] - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - return False - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True + if signed["mxid"] != event.state_key: return False - except (KeyError, SignatureVerifyException,): + if signed["token"] != token: return False + for public_key_object in self.get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + def get_public_keys(self, invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys + def _get_power_level_event(self, auth_events): key = (EventTypes.PowerLevels, "", ) return auth_events.get(key) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 84cbe710b3..8cf4d6169c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -32,7 +32,6 @@ class PresenceState(object): OFFLINE = u"offline" UNAVAILABLE = u"unavailable" ONLINE = u"online" - FREE_FOR_CHAT = u"free_for_chat" class JoinRules(object): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 6eff83e5f8..cd699ef27f 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -198,7 +198,10 @@ class Filter(object): sender = event.get("sender", None) if not sender: # Presence events have their 'sender' in content.user_id - sender = event.get("content", {}).get("user_id", None) + content = event.get("content") + # account_data has been allowed to have non-dict content, so check type first + if isinstance(content, dict): + sender = content.get("user_id") return self.check_fields( event.get("room_id", None), diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b4be7bdd0..de5ee988f1 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer from synapse import events @@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer): if name == "metrics" and self.get_config().enable_metrics: resources[METRICS_PREFIX] = MetricsResource(self) + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationResource(self) + root_resource = create_resource_tree(resources) if tls: reactor.listenSSL( diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 90718192dd..e8bfbe7cb5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -543,8 +543,19 @@ class FederationServer(FederationBase): return event @defer.inlineCallbacks - def exchange_third_party_invite(self, invite): - ret = yield self.handler.exchange_third_party_invite(invite) + def exchange_third_party_invite( + self, + sender_user_id, + target_user_id, + room_id, + signed, + ): + ret = yield self.handler.exchange_third_party_invite( + sender_user_id, + target_user_id, + room_id, + signed, + ) defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 65e054f7dd..6e92e2f8f4 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -425,7 +425,17 @@ class On3pidBindServlet(BaseFederationServlet): last_exception = None for invite in content["invites"]: try: - yield self.handler.exchange_third_party_invite(invite) + if "signed" not in invite or "token" not in invite["signed"]: + message = ("Rejecting received notification of third-" + "party invite without signed: %s" % (invite,)) + logger.info(message) + raise SynapseError(400, message) + yield self.handler.exchange_third_party_invite( + invite["sender"], + invite["mxid"], + invite["room_id"], + invite["signed"], + ) except Exception as e: last_exception = e if last_exception: diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 064e8723c8..bdade98bf7 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, RoomAlias +from synapse.types import UserID, RoomAlias, Requester from synapse.push.action_generator import ActionGenerator from synapse.util.logcontext import PreserveLoggingContext @@ -53,9 +53,15 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_clients(self, user_tuples, events, event_id_to_state): + def filter_events_for_clients(self, user_tuples, events, event_id_to_state): """ Returns dict of user_id -> list of events that user is allowed to see. + + :param (str, bool) user_tuples: (user id, is_peeking) for each + user to be checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the given + events """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -72,18 +78,20 @@ class BaseHandler(object): def allowed(event, user_id, is_peeking): state = event_id_to_state[event.event_id] + # get the room_visibility at the time of the event. visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) if visibility_event: visibility = visibility_event.content.get("history_visibility", "shared") else: visibility = "shared" + # if it was world_readable, it's easy: everyone can read it if visibility == "world_readable": return True - if is_peeking: - return False - + # get the user's membership at the time of the event. (or rather, + # just *after* the event. Which means that people can see their + # own join events, but not (currently) their own leave events.) membership_event = state.get((EventTypes.Member, user_id), None) if membership_event: if membership_event.event_id in event_id_forgotten: @@ -93,20 +101,29 @@ class BaseHandler(object): else: membership = None + # if the user was a member of the room at the time of the event, + # they can see it. if membership == Membership.JOIN: return True - if event.type == EventTypes.RoomHistoryVisibility: - return not is_peeking + if visibility == "joined": + # we weren't a member at the time of the event, so we can't + # see this event. + return False - if visibility == "shared": - return True - elif visibility == "joined": - return membership == Membership.JOIN elif visibility == "invited": + # user can also see the event if they were *invited* at the time + # of the event. return membership == Membership.INVITE - return True + else: + # visibility is shared: user can also see the event if they have + # become a member since the event + # + # XXX: if the user has subsequently joined and then left again, + # ideally we would share history up to the point they left. But + # we don't know when they left. + return not is_peeking defer.returnValue({ user_id: [ @@ -119,7 +136,17 @@ class BaseHandler(object): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, events, is_peeking=False): - # Assumes that user has at some point joined the room if not is_guest. + """ + Check which events a user is allowed to see + + :param str user_id: user id to be checked + :param [synapse.events.EventBase] events: list of events to be checked + :param bool is_peeking should be True if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the given + events + :rtype [synapse.events.EventBase] + """ types = ( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id), @@ -128,7 +155,7 @@ class BaseHandler(object): frozenset(e.event_id for e in events), types=types ) - res = yield self._filter_events_for_clients( + res = yield self.filter_events_for_clients( [(user_id, is_peeking)], events, event_id_to_state ) defer.returnValue(res.get(user_id, [])) @@ -147,7 +174,7 @@ class BaseHandler(object): @defer.inlineCallbacks def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_events_in_room( + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( builder.room_id, ) @@ -156,7 +183,10 @@ class BaseHandler(object): else: depth = 1 - prev_events = [(e, h) for e, h, _ in latest_ret] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] builder.prev_events = prev_events builder.depth = depth @@ -165,6 +195,31 @@ class BaseHandler(object): context = yield state_handler.compute_event_context(builder) + # If we've received an invite over federation, there are no latest + # events in the room, because we don't know enough about the graph + # fragment we received to treat it like a graph, so the above returned + # no relevant events. It may have returned some events (if we have + # joined and left the room), but not useful ones, like the invite. So we + # forcibly set our context to the invite we received over federation. + if ( + not self.is_host_in_room(context.current_state) and + builder.type == EventTypes.Member + ): + prev_member_event = yield self.store.get_room_member( + builder.sender, builder.room_id + ) + if prev_member_event: + builder.prev_events = ( + prev_member_event.event_id, + prev_member_event.prev_events + ) + + context = yield state_handler.compute_event_context( + builder, + old_state=(prev_member_event,), + outlier=True + ) + if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( context.prev_state_events @@ -187,10 +242,33 @@ class BaseHandler(object): (event, context,) ) + def is_host_in_room(self, current_state): + room_members = [ + (state_key, event.membership) + for ((event_type, state_key), event) in current_state.items() + if event_type == EventTypes.Member + ] + if len(room_members) == 0: + # Have we just created the room, and is this about to be the very + # first member event? + create_event = current_state.get(("m.room.create", "")) + if create_event: + return True + for (state_key, membership) in room_members: + if ( + UserID.from_string(state_key).domain == self.hs.hostname + and membership == Membership.JOIN + ): + return True + return False + @defer.inlineCallbacks - def handle_new_client_event(self, event, context, extra_users=[]): + def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): # We now need to go and hit out to wherever we need to hit out to. + if ratelimit: + self.ratelimit(event.sender) + self.auth.check(event, auth_events=context.current_state) yield self.maybe_kick_guest_users(event, context.current_state.values()) @@ -215,6 +293,12 @@ class BaseHandler(object): if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): + return ( + e.type == EventTypes.Member and + e.sender == event.sender + ) + event.unsigned["invite_room_state"] = [ { "type": e.type, @@ -228,7 +312,7 @@ class BaseHandler(object): EventTypes.CanonicalAlias, EventTypes.RoomAvatar, EventTypes.Name, - ) + ) or is_inviter_member_event(e) ] invitee = UserID.from_string(event.state_key) @@ -316,7 +400,8 @@ class BaseHandler(object): if member_event.type != EventTypes.Member: continue - if not self.hs.is_mine(UserID.from_string(member_event.state_key)): + target_user = UserID.from_string(member_event.state_key) + if not self.hs.is_mine(target_user): continue if member_event.content["membership"] not in { @@ -338,18 +423,13 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - message_handler = self.hs.get_handlers().message_handler - yield message_handler.create_and_send_event( - { - "type": EventTypes.Member, - "state_key": member_event.state_key, - "content": { - "membership": Membership.LEAVE, - "kind": "guest" - }, - "room_id": member_event.room_id, - "sender": member_event.state_key - }, + requester = Requester(target_user, "", True) + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + target_user, + member_event.room_id, + "leave", ratelimit=False, ) except Exception as e: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4efecb1ffd..e0a778e7ff 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -216,7 +216,7 @@ class DirectoryHandler(BaseHandler): aliases = yield self.store.get_aliases_for_room(room_id) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Aliases, "state_key": self.hs.hostname, "room_id": room_id, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4933c31c19..72a31a9755 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,8 @@ from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event from synapse.util.logcontext import preserve_context_over_fn +from synapse.api.constants import Membership, EventTypes +from synapse.events import EventBase from ._base import BaseHandler @@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) + presence_handler = self.hs.get_handlers().presence_handler - try: - if affect_presence: - yield self.started_stream(auth_user) - + context = yield presence_handler.user_syncing( + auth_user_id, affect_presence=affect_presence, + ) + with context: if timeout: # If they've set a timeout set a minimum limit. timeout = max(timeout, 500) @@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler): is_guest=is_guest, explicit_room_id=room_id ) + # When the user joins a new room, or another user joins a currently + # joined room, we need to send down presence for those users. + to_add = [] + for event in events: + if not isinstance(event, EventBase): + continue + if event.type == EventTypes.Member: + if event.membership != Membership.JOIN: + continue + # Send down presence. + if event.state_key == auth_user_id: + # Send down presence for everyone in the room. + users = yield self.store.get_users_in_room(event.room_id) + states = yield presence_handler.get_states( + users, + as_event=True, + ) + to_add.extend(states) + else: + + ev = yield presence_handler.get_state( + UserID.from_string(event.state_key), + as_event=True, + ) + to_add.append(ev) + + events.extend(to_add) + time_now = self.clock.time_msec() chunks = [ @@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) - finally: - if affect_presence: - self.stopped_stream(auth_user) - class EventHandler(BaseHandler): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index da55d43541..3655b9e5e2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -14,6 +14,9 @@ # limitations under the License. """Contains handlers for federation events.""" +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 from ._base import BaseHandler @@ -1620,19 +1623,15 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, invite): - sender = invite["sender"] - room_id = invite["room_id"] - - if "signed" not in invite or "token" not in invite["signed"]: - logger.info( - "Discarding received notification of third party invite " - "without signed: %s" % (invite,) - ) - return - + def exchange_third_party_invite( + self, + sender_user_id, + target_user_id, + room_id, + signed, + ): third_party_invite = { - "signed": invite["signed"], + "signed": signed, } event_dict = { @@ -1642,8 +1641,8 @@ class FederationHandler(BaseHandler): "third_party_invite": third_party_invite, }, "room_id": room_id, - "sender": sender, - "state_key": invite["mxid"], + "sender": sender_user_id, + "state_key": target_user_id, } if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): @@ -1656,11 +1655,11 @@ class FederationHandler(BaseHandler): ) self.auth.check(event, context.current_state) - yield self._validate_keyserver(event, auth_events=context.current_state) + yield self._check_signature(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context) + yield member_handler.send_membership_event(event, context, from_client=False) else: - destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) + destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.replication_layer.forward_third_party_invite( destinations, room_id, @@ -1681,13 +1680,13 @@ class FederationHandler(BaseHandler): ) self.auth.check(event, auth_events=context.current_state) - yield self._validate_keyserver(event, auth_events=context.current_state) + yield self._check_signature(event, auth_events=context.current_state) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context) + yield member_handler.send_membership_event(event, context, from_client=False) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): @@ -1711,17 +1710,69 @@ class FederationHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def _validate_keyserver(self, event, auth_events): - token = event.content["third_party_invite"]["signed"]["token"] + def _check_signature(self, event, auth_events): + """ + Checks that the signature in the event is consistent with its invite. + :param event (Event): The m.room.member event to check + :param auth_events (dict<(event type, state_key), event>) + + :raises + AuthError if signature didn't match any keys, or key has been + revoked, + SynapseError if a transient error meant a key couldn't be checked + for revocation. + """ + signed = event.content["third_party_invite"]["signed"] + token = signed["token"] invite_event = auth_events.get( (EventTypes.ThirdPartyInvite, token,) ) + if not invite_event: + raise AuthError(403, "Could not find invite") + + last_exception = None + for public_key_object in self.hs.get_auth().get_public_keys(invite_event): + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + + public_key = public_key_object["public_key"] + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + if "key_validity_url" in public_key_object: + yield self._check_key_revocation( + public_key, + public_key_object["key_validity_url"] + ) + return + except Exception as e: + last_exception = e + raise last_exception + + @defer.inlineCallbacks + def _check_key_revocation(self, public_key, url): + """ + Checks whether public_key has been revoked. + + :param public_key (str): base-64 encoded public key. + :param url (str): Key revocation URL. + + :raises + AuthError if they key has been revoked. + SynapseError if a transient error meant a key couldn't be checked + for revocation. + """ try: response = yield self.hs.get_simple_http_client().get_json( - invite_event.content["key_validity_url"], - {"public_key": invite_event.content["public_key"]} + url, + {"public_key": public_key} ) except Exception: raise SynapseError( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 82c8cb5f0c..afa7c9c36c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,12 +16,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -216,7 +215,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_event(self, event, context, ratelimit=True, is_guest=False): + def send_nonmember_event(self, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -226,55 +225,68 @@ class MessageHandler(BaseHandler): ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. """ + if event.type == EventTypes.Member: + raise SynapseError( + 500, + "Tried to send member event through non-member codepath" + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) - if ratelimit: - self.ratelimit(event.sender) - if event.is_state(): - prev_state = context.current_state.get((event.type, event.state_key)) - if prev_state and event.user_id == prev_state.user_id: - prev_content = encode_canonical_json(prev_state.content) - next_content = encode_canonical_json(event.content) - if prev_content == next_content: - # Duplicate suppression for state updates with same sender - # and content. - defer.returnValue(prev_state) + prev_state = self.deduplicate_state_event(event, context) + if prev_state is not None: + defer.returnValue(prev_state) - if event.type == EventTypes.Member: - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, is_guest=is_guest) - else: - yield self.handle_new_client_event( - event=event, - context=context, - ) + yield self.handle_new_client_event( + event=event, + context=context, + ratelimit=ratelimit, + ) if event.type == EventTypes.Message: presence = self.hs.get_handlers().presence_handler - with PreserveLoggingContext(): - presence.bump_presence_active_time(user) + yield presence.bump_presence_active_time(user) + + def deduplicate_state_event(self, event, context): + """ + Checks whether event is in the latest resolved state in context. + + If so, returns the version of the event in context. + Otherwise, returns None. + """ + prev_event = context.current_state.get((event.type, event.state_key)) + if prev_event and event.user_id == prev_event.user_id: + prev_content = encode_canonical_json(prev_event.content) + next_content = encode_canonical_json(event.content) + if prev_content == next_content: + return prev_event + return None @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True, - token_id=None, txn_id=None, is_guest=False): + def create_and_send_nonmember_event( + self, + event_dict, + ratelimit=True, + token_id=None, + txn_id=None + ): """ Creates an event, then sends it. - See self.create_event and self.send_event. + See self.create_event and self.send_nonmember_event. """ event, context = yield self.create_event( event_dict, token_id=token_id, txn_id=txn_id ) - yield self.send_event( + yield self.send_nonmember_event( event, context, ratelimit=ratelimit, - is_guest=is_guest ) defer.returnValue(event) @@ -660,10 +672,6 @@ class MessageHandler(BaseHandler): room_id=room_id, ) - # TODO(paul): I wish I was called with user objects not user_id - # strings... - auth_user = UserID.from_string(user_id) - # TODO: These concurrently time_now = self.clock.time_msec() state = [ @@ -688,13 +696,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): states = yield presence_handler.get_states( - target_users=[UserID.from_string(m.user_id) for m in room_members], - auth_user=auth_user, + [m.user_id for m in room_members], as_event=True, - check_auth=False, ) - defer.returnValue(states.values()) + defer.returnValue(states) @defer.inlineCallbacks def get_receipts(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b61394f2b5..f6cf343174 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -13,13 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +"""This module is responsible for keeping track of presence status of local +and remote users. -from synapse.api.errors import SynapseError, AuthError +The methods that define policy are: + - PresenceHandler._update_states + - PresenceHandler._handle_timeouts + - should_notify +""" + +from twisted.internet import defer, reactor +from contextlib import contextmanager + +from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState +from synapse.storage.presence import UserPresenceState -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function +from synapse.util.metrics import Measure +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID import synapse.metrics @@ -32,34 +45,32 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +notified_presence_counter = metrics.register_counter("notified_presence") +federation_presence_out_counter = metrics.register_counter("federation_presence_out") +presence_updates_counter = metrics.register_counter("presence_updates") +timers_fired_counter = metrics.register_counter("timers_fired") +federation_presence_counter = metrics.register_counter("federation_presence") +bump_active_time_counter = metrics.register_counter("bump_active_time") -# Don't bother bumping "last active" time if it differs by less than 60 seconds -LAST_ACTIVE_GRANULARITY = 60 * 1000 - -# Keep no more than this number of offline serial revisions -MAX_OFFLINE_SERIALS = 1000 +# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them +# "currently_active" +LAST_ACTIVE_GRANULARITY = 60 * 1000 -# TODO(paul): Maybe there's one of these I can steal from somewhere -def partition(l, func): - """Partition the list by the result of func applied to each element.""" - ret = {} - - for x in l: - key = func(x) - if key not in ret: - ret[key] = [] - ret[key].append(x) +# How long to wait until a new /events or /sync request before assuming +# the client has gone. +SYNC_ONLINE_TIMEOUT = 30 * 1000 - return ret +# How long to wait before marking the user as idle. Compared against last active +IDLE_TIMER = 5 * 60 * 1000 +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 30 * 60 * 1000 -def partitionbool(l, func): - def boolfunc(x): - return bool(func(x)) +# How often to resend presence to remote servers +FEDERATION_PING_INTERVAL = 25 * 60 * 1000 - ret = partition(l, boolfunc) - return ret.get(True, []), ret.get(False, []) +assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER def user_presence_changed(distributor, user, statuscache): @@ -72,45 +83,13 @@ def collect_presencelike_data(distributor, user, content): class PresenceHandler(BaseHandler): - STATE_LEVELS = { - PresenceState.OFFLINE: 0, - PresenceState.UNAVAILABLE: 1, - PresenceState.ONLINE: 2, - PresenceState.FREE_FOR_CHAT: 3, - } - def __init__(self, hs): super(PresenceHandler, self).__init__(hs) - - self.homeserver = hs - + self.hs = hs self.clock = hs.get_clock() - - distributor = hs.get_distributor() - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "started_user_eventstream", self.started_user_eventstream - ) - distributor.observe( - "stopped_user_eventstream", self.stopped_user_eventstream - ) - - distributor.observe("user_joined_room", self.user_joined_room) - - distributor.declare("collect_presencelike_data") - - distributor.declare("changed_presencelike_data") - distributor.observe( - "changed_presencelike_data", self.changed_presencelike_data - ) - - # outbound signal from the presence module to advertise when a user's - # presence has changed - distributor.declare("user_presence_changed") - - self.distributor = distributor - + self.store = hs.get_datastore() + self.wheel_timer = WheelTimer() + self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() self.federation.register_edu_handler( @@ -138,348 +117,552 @@ class PresenceHandler(BaseHandler): ) ) - # IN-MEMORY store, mapping local userparts to sets of local users to - # be informed of state changes. - self._local_pushmap = {} - # map local users to sets of remote /domain names/ who are interested - # in them - self._remote_sendmap = {} - # map remote users to sets of local users who're interested in them - self._remote_recvmap = {} - # list of (serial, set of(userids)) tuples, ordered by serial, latest - # first - self._remote_offline_serials = [] - - # map any user to a UserPresenceCache - self._user_cachemap = {} - self._user_cachemap_latest_serial = 0 - - # map room_ids to the latest presence serial for a member of that - # room - self._room_serials = {} + distributor = hs.get_distributor() + distributor.observe("user_joined_room", self.user_joined_room) + + active_presence = self.store.take_presence_startup_info() + + # A dictionary of the current state of users. This is prefilled with + # non-offline presence from the DB. We should fetch from the DB if + # we can't find a users presence in here. + self.user_to_current_state = { + state.user_id: state + for state in active_presence + } metrics.register_callback( - "userCachemap:size", - lambda: len(self._user_cachemap), + "user_to_current_state_size", lambda: len(self.user_to_current_state) ) - def _get_or_make_usercache(self, user): - """If the cache entry doesn't exist, initialise a new one.""" - if user not in self._user_cachemap: - self._user_cachemap[user] = UserPresenceCache() - return self._user_cachemap[user] - - def _get_or_offline_usercache(self, user): - """If the cache entry doesn't exist, return an OFFLINE one but do not - store it into the cache.""" - if user in self._user_cachemap: - return self._user_cachemap[user] - else: - return UserPresenceCache() + now = self.clock.time_msec() + for state in active_presence: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_active_ts + IDLE_TIMER, + ) + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ) + if self.hs.is_mine_id(state.user_id): + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL, + ) + else: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update_ts + FEDERATION_TIMEOUT, + ) - def registered_user(self, user): - return self.store.create_presence(user.localpart) + # Set of users who have presence in the `user_to_current_state` that + # have not yet been persisted + self.unpersisted_users_changes = set() - @defer.inlineCallbacks - def is_presence_visible(self, observer_user, observed_user): - assert(self.hs.is_mine(observed_user)) + reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) - if observer_user == observed_user: - defer.returnValue(True) + self.serial_to_user = {} + self._next_serial = 1 - if (yield self.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user])): - defer.returnValue(True) + # Keeps track of the number of *ongoing* syncs. While this is non zero + # a user will never go offline. + self.user_to_num_current_syncs = {} - if (yield self.store.is_presence_visible( - observed_localpart=observed_user.localpart, - observer_userid=observer_user.to_string())): - defer.returnValue(True) + # Start a LoopingCall in 30s that fires every 5s. + # The initial delay is to allow disconnected clients a chance to + # reconnect before we treat them as offline. + self.clock.call_later( + 0 * 1000, + self.clock.looping_call, + self._handle_timeouts, + 5000, + ) - defer.returnValue(False) + metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) @defer.inlineCallbacks - def get_state(self, target_user, auth_user, as_event=False, check_auth=True): - """Get the current presence state of the given user. + def _on_shutdown(self): + """Gets called when shutting down. This lets us persist any updates that + we haven't yet persisted, e.g. updates that only changes some internal + timers. This allows changes to persist across startup without having to + persist every single change. + + If this does not run it simply means that some of the timers will fire + earlier than they should when synapse is restarted. This affect of this + is some spurious presence changes that will self-correct. + """ + logger.info( + "Performing _on_shutdown. Persiting %d unpersisted changes", + len(self.user_to_current_state) + ) - Args: - target_user (UserID): The user whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_user` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + if self.unpersisted_users_changes: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ]) + logger.info("Finished _on_shutdown") - Returns: - dict: The presence state of the `target_user`, whose format depends - on the `as_event` argument. + @defer.inlineCallbacks + def _update_states(self, new_states): + """Updates presence of users. Sets the appropriate timeouts. Pokes + the notifier and federation if and only if the changed presence state + should be sent to clients/servers. """ - if self.hs.is_mine(target_user): - if check_auth: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=target_user - ) + now = self.clock.time_msec() - if not visible: - raise SynapseError(404, "Presence information not visible") + with Measure(self.clock, "presence_update_states"): - if target_user in self._user_cachemap: - state = self._user_cachemap[target_user].get_state() - else: - state = yield self.store.get_presence_state(target_user.localpart) - if "mtime" in state: - del state["mtime"] - state["presence"] = state.pop("state") - else: - # TODO(paul): Have remote server send us permissions set - state = self._get_or_offline_usercache(target_user).get_state() + # NOTE: We purposefully don't yield between now and when we've + # calculated what we want to do with the new states, to avoid races. - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + to_notify = {} # Changes we want to notify everyone about + to_federation_ping = {} # These need sending keep-alives - if as_event: - content = state + for new_state in new_states: + user_id = new_state.user_id - content["user_id"] = target_user.to_string() + # Its fine to not hit the database here, as the only thing not in + # the current state cache are OFFLINE states, where the only field + # of interest is last_active which is safe enough to assume is 0 + # here. + prev_state = self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") + new_state, should_notify, should_ping = handle_update( + prev_state, new_state, + is_mine=self.hs.is_mine_id(user_id), + wheel_timer=self.wheel_timer, + now=now ) - defer.returnValue({"type": "m.presence", "content": content}) - else: - defer.returnValue(state) + self.user_to_current_state[user_id] = new_state - @defer.inlineCallbacks - def get_states(self, target_users, auth_user, as_event=False, check_auth=True): - """A batched version of the `get_state` method that accepts a list of - `target_users` + if should_notify: + to_notify[user_id] = new_state + elif should_ping: + to_federation_ping[user_id] = new_state - Args: - target_users (list): The list of UserID's whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_users` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + # TODO: We should probably ensure there are no races hereafter - Returns: - dict: A mapping from user -> presence_state - """ - local_users, remote_users = partitionbool( - target_users, - lambda u: self.hs.is_mine(u) - ) + presence_updates_counter.inc_by(len(new_states)) + + if to_notify: + notified_presence_counter.inc_by(len(to_notify)) + yield self._persist_and_notify(to_notify.values()) + + self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes -= set(to_notify.keys()) - if check_auth: - for user in local_users: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=user + to_federation_ping = { + user_id: state for user_id, state in to_federation_ping.items() + if user_id not in to_notify + } + if to_federation_ping: + federation_presence_out_counter.inc_by(len(to_federation_ping)) + + _, _, hosts_to_states = yield self._get_interested_parties( + to_federation_ping.values() ) - if not visible: - raise SynapseError(404, "Presence information not visible") + self._push_to_remotes(hosts_to_states) + + def _handle_timeouts(self): + """Checks the presence of users that have timed out and updates as + appropriate. + """ + now = self.clock.time_msec() + + with Measure(self.clock, "presence_handle_timeouts"): + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = self.wheel_timer.fetch(now) - results = {} - if local_users: - for user in local_users: - if user in self._user_cachemap: - results[user] = self._user_cachemap[user].get_state() + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in set(users_to_check) + ] - local_to_user = {u.localpart: u for u in local_users} + timers_fired_counter.inc_by(len(states)) - states = yield self.store.get_presence_states( - [u.localpart for u in local_users if u not in results] + changes = handle_timeouts( + states, + is_mine_fn=self.hs.is_mine_id, + user_to_num_current_syncs=self.user_to_num_current_syncs, + now=now, ) - for local_part, state in states.items(): - if state is None: - continue - res = {"presence": state["state"]} - if "status_msg" in state and state["status_msg"]: - res["status_msg"] = state["status_msg"] - results[local_to_user[local_part]] = res - - for user in remote_users: - # TODO(paul): Have remote server send us permissions set - results[user] = self._get_or_offline_usercache(user).get_state() - - for state in results.values(): - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + preserve_fn(self._update_states)(changes) - if as_event: - for user, state in results.items(): - content = state - content["user_id"] = user.to_string() + @defer.inlineCallbacks + def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. + """ + user_id = user.to_string() - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + bump_active_time_counter.inc() - results[user] = {"type": "m.presence", "content": content} + prev_state = yield self.current_state_for_user(user_id) - defer.returnValue(results) + new_fields = { + "last_active_ts": self.clock.time_msec(), + } + if prev_state.state == PresenceState.UNAVAILABLE: + new_fields["state"] = PresenceState.ONLINE + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks - @log_function - def set_state(self, target_user, auth_user, state): - # return - # TODO (erikj): Turn this back on. Why did we end up sending EDUs - # everywhere? + def user_syncing(self, user_id, affect_presence=True): + """Returns a context manager that should surround any stream requests + from the user. - if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. - if target_user != auth_user: - raise AuthError(400, "Cannot set another user's presence") + Args: + user_id (str) + affect_presence (bool): If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + if affect_presence: + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + + prev_state = yield self.current_state_for_user(user_id) + if prev_state.state == PresenceState.OFFLINE: + # If they're currently offline then bring them online, otherwise + # just update the last sync times. + yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + )]) + else: + yield self._update_states([prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec(), + )]) - if "status_msg" not in state: - state["status_msg"] = None + @defer.inlineCallbacks + def _end(): + if affect_presence: + self.user_to_num_current_syncs[user_id] -= 1 - for k in state.keys(): - if k not in ("presence", "status_msg"): - raise SynapseError( - 400, "Unexpected presence state key '%s'" % (k,) - ) + prev_state = yield self.current_state_for_user(user_id) + yield self._update_states([prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec(), + )]) - if state["presence"] not in self.STATE_LEVELS: - raise SynapseError(400, "'%s' is not a valid presence state" % ( - state["presence"], - )) + @contextmanager + def _user_syncing(): + try: + yield + finally: + preserve_fn(_end)() - logger.debug("Updating presence state of %s to %s", - target_user.localpart, state["presence"]) + defer.returnValue(_user_syncing()) - state_to_store = dict(state) - state_to_store["state"] = state_to_store.pop("presence") + @defer.inlineCallbacks + def current_state_for_user(self, user_id): + """Get the current presence state for a user. + """ + res = yield self.current_state_for_users([user_id]) + defer.returnValue(res[user_id]) - statuscache = self._get_or_offline_usercache(target_user) - was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] - now_level = self.STATE_LEVELS[state["presence"]] + @defer.inlineCallbacks + def current_state_for_users(self, user_ids): + """Get the current presence state for multiple users. - yield self.store.set_presence_state( - target_user.localpart, state_to_store - ) - yield collect_presencelike_data(self.distributor, target_user, state) + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = yield self.store.get_presence_for_users(missing) + states.update({state.user_id: state for state in res}) + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + new = { + user_id: UserPresenceState.default(user_id) + for user_id in missing + } + states.update(new) + self.user_to_current_state.update(new) + + defer.returnValue(states) + + @defer.inlineCallbacks + def _get_interested_parties(self, states): + """Given a list of states return which entities (rooms, users, servers) + are interested in the given states. + + Returns: + 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + events = yield self.store.get_rooms_for_user(state.user_id) + for e in events: + room_ids_to_states.setdefault(e.room_id, []).append(state) + + plist = yield self.store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + hosts_to_states = {} + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) + if not local_states: + continue - if now_level > was_level: - state["last_active"] = self.clock.time_msec() + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) - now_online = state["presence"] != PresenceState.OFFLINE - was_polling = target_user in self._user_cachemap + for user_id, states in users_to_states.items(): + local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) + if not local_states: + continue - if now_online and not was_polling: - yield self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - yield self.stop_polling_presence(target_user) + host = UserID.from_string(user_id).domain + hosts_to_states.setdefault(host, []).extend(local_states) - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - yield self.changed_presencelike_data(target_user, state) + # TODO: de-dup hosts_to_states, as a single host might have multiple + # of same presence - def bump_presence_active_time(self, user, now=None): - if now is None: - now = self.clock.time_msec() + defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) - prev_state = self._get_or_make_usercache(user) - if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: - return + @defer.inlineCallbacks + def _persist_and_notify(self, states): + """Persist states in the database, poke the notifier and send to + interested remote servers + """ + stream_id, max_token = yield self.store.update_presence(states) + + parties = yield self._get_interested_parties(states) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) - with PreserveLoggingContext(): - self.changed_presencelike_data(user, {"last_active": now}) + self._push_to_remotes(hosts_to_states) - def get_joined_rooms_for_user(self, user): - """Get the list of rooms a user is joined to. + def _push_to_remotes(self, hosts_to_states): + """Sends state updates to remote servers. Args: - user(UserID): The user. - Returns: - A Deferred of a list of room id strings. + hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` """ - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_joined_rooms_for_user(user) + now = self.clock.time_msec() + for host, states in hosts_to_states.items(): + self.federation.send_edu( + destination=host, + edu_type="m.presence", + content={ + "push": [ + _format_user_presence_state(state, now) + for state in states + ] + } + ) - def get_joined_users_for_room_id(self, room_id): - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_room_members(room_id) + @defer.inlineCallbacks + def incoming_presence(self, origin, content): + """Called when we receive a `m.presence` EDU from a remote server. + """ + now = self.clock.time_msec() + updates = [] + for push in content.get("push", []): + # A "push" contains a list of presence that we are probably interested + # in. + # TODO: Actually check if we're interested, rather than blindly + # accepting presence updates. + user_id = push.get("user_id", None) + if not user_id: + logger.info( + "Got presence update from %r with no 'user_id': %r", + origin, push, + ) + continue + + presence_state = push.get("presence", None) + if not presence_state: + logger.info( + "Got presence update from %r with no 'presence_state': %r", + origin, push, + ) + continue + + new_fields = { + "state": presence_state, + "last_federation_update_ts": now, + } + + last_active_ago = push.get("last_active_ago", None) + if last_active_ago is not None: + new_fields["last_active_ts"] = now - last_active_ago + + new_fields["status_msg"] = push.get("status_msg", None) + new_fields["currently_active"] = push.get("currently_active", False) + + prev_state = yield self.current_state_for_user(user_id) + updates.append(prev_state.copy_and_replace(**new_fields)) + + if updates: + federation_presence_counter.inc_by(len(updates)) + yield self._update_states(updates) @defer.inlineCallbacks - def changed_presencelike_data(self, user, state): - """Updates the presence state of a local user. + def get_state(self, target_user, as_event=False): + results = yield self.get_states( + [target_user.to_string()], + as_event=as_event, + ) + + defer.returnValue(results[0]) + + @defer.inlineCallbacks + def get_states(self, target_user_ids, as_event=False): + """Get the presence state for users. Args: - user(UserID): The user being updated. - state(dict): The new presence state for the user. + target_user_ids (list) + as_event (bool): Whether to format it as a client event or not. + Returns: - A Deferred + list """ - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache(user, state) - yield self.push_presence(user, statuscache=statuscache) - @log_function - def started_user_eventstream(self, user): - # TODO(paul): Use "last online" state - return self.set_state(user, user, {"presence": PresenceState.ONLINE}) + updates = yield self.current_state_for_users(target_user_ids) + updates = updates.values() - @log_function - def stopped_user_eventstream(self, user): - # TODO(paul): Save current state as "last online" state - return self.set_state(user, user, {"presence": PresenceState.OFFLINE}) + for user_id in set(target_user_ids) - set(u.user_id for u in updates): + updates.append(UserPresenceState.default(user_id)) + + now = self.clock.time_msec() + if as_event: + defer.returnValue([ + { + "type": "m.presence", + "content": _format_user_presence_state(state, now), + } + for state in updates + ]) + else: + defer.returnValue([ + _format_user_presence_state(state, now) for state in updates + ]) @defer.inlineCallbacks - def user_joined_room(self, user, room_id): - """Called via the distributor whenever a user joins a room. - Notifies the new member of the presence of the current members. - Notifies the current members of the room of the new member's presence. + def set_state(self, target_user, state): + """Set the presence state of the user. + """ + status_msg = state.get("status_msg", None) + presence = state["presence"] - Args: - user(UserID): The user who joined the room. - room_id(str): The room id the user joined. + valid_presence = ( + PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE + ) + if presence not in valid_presence: + raise SynapseError(400, "Invalid presence state") + + user_id = target_user.to_string() + + prev_state = yield self.current_state_for_user(user_id) + + new_fields = { + "state": presence, + "status_msg": status_msg if presence != PresenceState.OFFLINE else None + } + + if presence == PresenceState.ONLINE: + new_fields["last_active_ts"] = self.clock.time_msec() + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + + @defer.inlineCallbacks + def user_joined_room(self, user, room_id): + """Called (via the distributor) when a user joins a room. This funciton + sends presence updates to servers, either: + 1. the joining user is a local user and we send their presence to + all servers in the room. + 2. the joining user is a remote user and so we send presence for all + local users in the room. """ + # We only need to send presence to servers that don't have it yet. We + # don't need to send to local clients here, as that is done as part + # of the event stream/sync. + # TODO: Only send to servers not already in the room. if self.hs.is_mine(user): - # No actual update but we need to bump the serial anyway for the - # event source - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache( - user, room_ids=[room_id] - ) - self.push_update_to_local_and_remote( - observed_user=user, - room_ids=[room_id], - statuscache=statuscache, - ) + state = yield self.current_state_for_user(user.to_string()) - # We also want to tell them about current presence of people. - curr_users = yield self.get_joined_users_for_room_id(room_id) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + self._push_to_remotes({host: (state,) for host in hosts}) + else: + user_ids = yield self.store.get_users_in_room(room_id) + user_ids = filter(self.hs.is_mine_id, user_ids) - for local_user in [c for c in curr_users if self.hs.is_mine(c)]: - statuscache = yield self.update_presence_cache( - local_user, room_ids=[room_id], add_to_cache=False - ) + states = yield self.current_state_for_users(user_ids) - with PreserveLoggingContext(): - self.push_update_to_local_and_remote( - observed_user=local_user, - users_to_push=[user], - statuscache=statuscache, - ) + self._push_to_remotes({user.domain: states.values()}) @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Request the presence of a local or remote user for a local user""" + def get_presence_list(self, observer_user, accepted=None): + """Returns the presence for all users in their presence list. + """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") + presence_list = yield self.store.get_presence_list( + observer_user.localpart, accepted=accepted + ) + + results = yield self.get_states( + target_user_ids=[row["observed_user_id"] for row in presence_list], + as_event=False, + ) + + is_accepted = { + row["observed_user_id"]: row["accepted"] for row in presence_list + } + + for result in results: + result.update({ + "accepted": is_accepted, + }) + + defer.returnValue(results) + + @defer.inlineCallbacks + def send_presence_invite(self, observer_user, observed_user): + """Sends a presence invite. + """ yield self.store.add_presence_list_pending( observer_user.localpart, observed_user.to_string() ) @@ -497,59 +680,40 @@ class PresenceHandler(BaseHandler): ) @defer.inlineCallbacks - def _should_accept_invite(self, observed_user, observer_user): - if not self.hs.is_mine(observed_user): - defer.returnValue(False) - - row = yield self.store.has_presence_state(observed_user.localpart) - if not row: - defer.returnValue(False) - - # TODO(paul): Eventually we'll ask the user's permission for this - # before accepting. For now just accept any invite request - defer.returnValue(True) - - @defer.inlineCallbacks def invite_presence(self, observed_user, observer_user): - """Handles a m.presence_invite EDU. A remote or local user has - requested presence updates for a local user. If the invite is accepted - then allow the local or remote user to see the presence of the local - user. - - Args: - observed_user(UserID): The local user whose presence is requested. - observer_user(UserID): The remote or local user requesting presence. + """Handles new presence invites. """ - accept = yield self._should_accept_invite(observed_user, observer_user) - - if accept: - yield self.store.allow_presence_visible( - observed_user.localpart, observer_user.to_string() - ) + if not self.hs.is_mine(observed_user): + raise SynapseError(400, "User is not hosted on this Home Server") + # TODO: Don't auto accept if self.hs.is_mine(observer_user): - if accept: - yield self.accept_presence(observed_user, observer_user) - else: - yield self.deny_presence(observed_user, observer_user) + yield self.accept_presence(observed_user, observer_user) else: - edu_type = "m.presence_accept" if accept else "m.presence_deny" - - yield self.federation.send_edu( + self.federation.send_edu( destination=observer_user.domain, - edu_type=edu_type, + edu_type="m.presence_accept", content={ "observed_user": observed_user.to_string(), "observer_user": observer_user.to_string(), } ) + state_dict = yield self.get_state(observed_user, as_event=False) + + self.federation.send_edu( + destination=observer_user.domain, + edu_type="m.presence", + content={ + "push": [state_dict] + } + ) + @defer.inlineCallbacks def accept_presence(self, observed_user, observer_user): """Handles a m.presence_accept EDU. Mark a presence invite from a local or remote user as accepted in a local user's presence list. Starts polling for presence updates from the local or remote user. - Args: observed_user(UserID): The user to update in the presence list. observer_user(UserID): The owner of the presence list to update. @@ -558,15 +722,10 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - yield self.start_polling_presence( - observer_user, target_user=observed_user - ) - @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): """Handle a m.presence_deny EDU. Removes a local or remote user from a local user's presence list. - Args: observed_user(UserID): The local or remote user to remove from the list. @@ -584,7 +743,6 @@ class PresenceHandler(BaseHandler): def drop(self, observed_user, observer_user): """Remove a local or remote user from a local user's presence list and unsubscribe the local user from updates that user. - Args: observed_user(UserId): The local or remote user to remove from the list. @@ -599,710 +757,353 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.stop_polling_presence( - observer_user, target_user=observed_user - ) - - @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Get the presence list for a local user. The retured list includes - the current presence state for each user listed. - - Args: - observer_user(UserID): The local user whose presence list to fetch. - accepted(bool or None): If not none then only include users who - have or have not accepted the presence invite request. - Returns: - A Deferred list of presence state events. - """ - if not self.hs.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) - - results = [] - for row in presence_list: - observed_user = UserID.from_string(row["observed_user_id"]) - result = { - "observed_user": observed_user, "accepted": row["accepted"] - } - result.update( - self._get_or_offline_usercache(observed_user).get_state() - ) - if "last_active" in result: - result["last_active_ago"] = int( - self.clock.time_msec() - result.pop("last_active") - ) - results.append(result) - - defer.returnValue(results) + # TODO: Inform the remote that we've dropped the presence list. @defer.inlineCallbacks - @log_function - def start_polling_presence(self, user, target_user=None, state=None): - """Subscribe a local user to presence updates from a local or remote - user. If no target_user is supplied then subscribe to all users stored - in the presence list for the local user. - - Additonally this pushes the current presence state of this user to all - target_users. That state can be provided directly or will be read from - the stored state for the local user. - - Also this attempts to notify the local user of the current state of - any local target users. - - Args: - user(UserID): The local user that whishes for presence updates. - target_user(UserID): The local or remote user whose updates are - wanted. - state(dict): Optional presence state for the local user. + def is_visible(self, observed_user, observer_user): + """Returns whether a user can see another user's presence. """ - logger.debug("Start polling for presence from %s", user) - - if target_user: - target_users = set([target_user]) - room_ids = [] - else: - presence = yield self.store.get_presence_list( - user.localpart, accepted=True - ) - target_users = set([ - UserID.from_string(x["observed_user_id"]) for x in presence - ]) + observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) + observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) - # Also include people in all my rooms + observer_room_ids = set(r.room_id for r in observer_rooms) + observed_room_ids = set(r.room_id for r in observed_rooms) - room_ids = yield self.get_joined_rooms_for_user(user) + if observer_room_ids & observed_room_ids: + defer.returnValue(True) - if state is None: - state = yield self.store.get_presence_state(user.localpart) - else: - # statuscache = self._get_or_make_usercache(user) - # self._user_cachemap_latest_serial += 1 - # statuscache.update(state, self._user_cachemap_latest_serial) - pass - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=target_users, - room_ids=room_ids, - statuscache=self._get_or_make_usercache(user), + accepted_observers = yield self.store.get_presence_list_observers_accepted( + observed_user.to_string() ) - for target_user in target_users: - if self.hs.is_mine(target_user): - self._start_polling_local(user, target_user) - - # We want to tell the person that just came online - # presence state of people they are interested in? - self.push_update_to_clients( - users_to_push=[user], - ) - - deferreds = [] - remote_users = [u for u in target_users if not self.hs.is_mine(u)] - remoteusers_by_domain = partition(remote_users, lambda u: u.domain) - # Only poll for people in our get_presence_list - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append(self._start_polling_remote( - user, domain, remoteusers - )) - - yield defer.DeferredList(deferreds, consumeErrors=True) + defer.returnValue(observer_user.to_string() in accepted_observers) - def _start_polling_local(self, user, target_user): - """Subscribe a local user to presence updates for a local user - - Args: - user(UserId): The local user that wishes for updates. - target_user(UserId): The local users whose updates are wanted. + @defer.inlineCallbacks + def get_all_presence_updates(self, last_id, current_id): """ - target_localpart = target_user.localpart - - if target_localpart not in self._local_pushmap: - self._local_pushmap[target_localpart] = set() - - self._local_pushmap[target_localpart].add(user) - - def _start_polling_remote(self, user, domain, remoteusers): - """Subscribe a local user to presence updates for remote users on a - given remote domain. - - Args: - user(UserID): The local user that wishes for updates. - domain(str): The remote server the local user wants updates from. - remoteusers(UserID): The remote users that local user wants to be - told about. - Returns: - A Deferred. + Gets a list of presence update rows from between the given stream ids. + Each row has: + - stream_id(str) + - user_id(str) + - state(str) + - last_active_ts(int) + - last_federation_update_ts(int) + - last_user_sync_ts(int) + - status_msg(int) + - currently_active(int) """ - to_poll = set() - - for u in remoteusers: - if u not in self._remote_recvmap: - self._remote_recvmap[u] = set() - to_poll.add(u) - - self._remote_recvmap[u].add(user) - - if not to_poll: - return defer.succeed(None) + # TODO(markjh): replicate the unpersisted changes. + # This could use the in-memory stores for recent changes. + rows = yield self.store.get_all_presence_updates(last_id, current_id) + defer.returnValue(rows) - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"poll": [u.to_string() for u in to_poll]} - ) - - @log_function - def stop_polling_presence(self, user, target_user=None): - """Unsubscribe a local user from presence updates from a local or - remote user. If no target user is supplied then unsubscribe the user - from all presence updates that the user had subscribed to. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID or None): The user whose updates are no longer - wanted. - Returns: - A Deferred. - """ - logger.debug("Stop polling for presence from %s", user) - if not target_user or self.hs.is_mine(target_user): - self._stop_polling_local(user, target_user=target_user) +def should_notify(old_state, new_state): + """Decides if a presence state change should be sent to interested parties. + """ + if old_state.status_msg != new_state.status_msg: + return True - deferreds = [] + if old_state.state == PresenceState.ONLINE: + if new_state.state != PresenceState.ONLINE: + # Always notify for online -> anything + return True - if target_user: - if target_user not in self._remote_recvmap: - return - target_users = set([target_user]) - else: - target_users = self._remote_recvmap.keys() + if new_state.currently_active != old_state.currently_active: + return True - remoteusers = [u for u in target_users - if user in self._remote_recvmap[u]] - remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + return True - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] + if old_state.state != new_state.state: + return True - deferreds.append( - self._stop_polling_remote(user, domain, remoteusers) - ) + return False - return defer.DeferredList(deferreds, consumeErrors=True) - def _stop_polling_local(self, user, target_user): - """Unsubscribe a local user from presence updates from a local user on - this server. +def _format_user_presence_state(state, now): + """Convert UserPresenceState to a format that can be sent down to clients + and to other servers. + """ + content = { + "presence": state.state, + "user_id": state.user_id, + } + if state.last_active_ts: + content["last_active_ago"] = now - state.last_active_ts + if state.status_msg and state.state != PresenceState.OFFLINE: + content["status_msg"] = state.status_msg + if state.state == PresenceState.ONLINE: + content["currently_active"] = state.currently_active - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID): The user whose updates are no longer wanted. - """ - for localpart in self._local_pushmap.keys(): - if target_user and localpart != target_user.localpart: - continue + return content - if user in self._local_pushmap[localpart]: - self._local_pushmap[localpart].remove(user) - if not self._local_pushmap[localpart]: - del self._local_pushmap[localpart] +class PresenceEventSource(object): + def __init__(self, hs): + self.hs = hs + self.clock = hs.get_clock() + self.store = hs.get_datastore() + @defer.inlineCallbacks @log_function - def _stop_polling_remote(self, user, domain, remoteusers): - """Unsubscribe a local user from presence updates from remote users on - a given domain. - - Args: - user(UserID): The local user that no longer wishes for updates. - domain(str): The remote server to unsubscribe from. - remoteusers([UserID]): The users on that remote server that the - local user no longer wishes to be updated about. - Returns: - A Deferred. - """ - to_unpoll = set() - - for u in remoteusers: - self._remote_recvmap[u].remove(user) - - if not self._remote_recvmap[u]: - del self._remote_recvmap[u] - to_unpoll.add(u) + def get_new_events(self, user, from_key, room_ids=None, include_offline=True, + **kwargs): + # The process for getting presence events are: + # 1. Get the rooms the user is in. + # 2. Get the list of user in the rooms. + # 3. Get the list of users that are in the user's presence list. + # 4. If there is a from_key set, cross reference the list of users + # with the `presence_stream_cache` to see which ones we actually + # need to check. + # 5. Load current state for the users. + # + # We don't try and limit the presence updates by the current token, as + # sending down the rare duplicate is not a concern. + + with Measure(self.clock, "presence.get_new_events"): + user_id = user.to_string() + if from_key is not None: + from_key = int(from_key) + room_ids = room_ids or [] - if not to_unpoll: - return defer.succeed(None) + presence = self.hs.get_handlers().presence_handler + stream_change_cache = self.store.presence_stream_cache - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"unpoll": [u.to_string() for u in to_unpoll]} - ) + if not room_ids: + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(e.room_id for e in rooms) + else: + room_ids = set(room_ids) + + max_token = self.store.get_current_presence_token() + + plist = yield self.store.get_presence_list_accepted(user.localpart) + friends = set(row["observed_user_id"] for row in plist) + friends.add(user_id) # So that we receive our own presence + + user_ids_changed = set() + changed = None + if from_key and max_token - from_key < 100: + # For small deltas, its quicker to get all changes and then + # work out if we share a room or they're in our presence list + changed = stream_change_cache.get_all_entities_changed(from_key) + + # get_all_entities_changed can return None + if changed is not None: + for other_user_id in changed: + if other_user_id in friends: + user_ids_changed.add(other_user_id) + continue + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + continue + else: + # Too many possible updates. Find all users we can see and check + # if any of them have changed. + user_ids_to_check = set() + for room_id in room_ids: + users = yield self.store.get_users_in_room(room_id) + user_ids_to_check.update(users) + + user_ids_to_check.update(friends) + + # Always include yourself. Only really matters for when the user is + # not in any rooms, but still. + user_ids_to_check.add(user_id) + + if from_key: + user_ids_changed = stream_change_cache.get_entities_changed( + user_ids_to_check, from_key, + ) + else: + user_ids_changed = user_ids_to_check - @defer.inlineCallbacks - @log_function - def push_presence(self, user, statuscache): - """ - Notify local and remote users of a change in presence of a local user. - Pushes the update to local clients and remote domains that are directly - subscribed to the presence of the local user. - Also pushes that update to any local user or remote domain that shares - a room with the local user. + updates = yield presence.current_state_for_users(user_ids_changed) - Args: - user(UserID): The local user whose presence was updated. - statuscache(UserPresenceCache): Cache of the user's presence state - Returns: - A Deferred. - """ - assert(self.hs.is_mine(user)) + now = self.clock.time_msec() - logger.debug("Pushing presence update from %s", user) + defer.returnValue(([ + { + "type": "m.presence", + "content": _format_user_presence_state(s, now), + } + for s in updates.values() + if include_offline or s.state != PresenceState.OFFLINE + ], max_token)) - localusers = set(self._local_pushmap.get(user.localpart, set())) - remotedomains = set(self._remote_sendmap.get(user.localpart, set())) + def get_current_key(self): + return self.store.get_current_presence_token() - # Reflect users' status changes back to themselves, so UIs look nice - # and also user is informed of server-forced pushes - localusers.add(user) + def get_pagination_rows(self, user, pagination_config, key): + return self.get_new_events(user, from_key=None, include_offline=False) - room_ids = yield self.get_joined_rooms_for_user(user) - if not localusers and not room_ids: - defer.returnValue(None) +def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): + """Checks the presence of users that have timed out and updates as + appropriate. - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=localusers, - remote_domains=remotedomains, - room_ids=room_ids, - statuscache=statuscache, - ) - yield user_presence_changed(self.distributor, user, statuscache) + Args: + user_states(list): List of UserPresenceState's to check. + is_mine_fn (fn): Function that returns if a user_id is ours + user_to_num_current_syncs (dict): Mapping of user_id to number of currently + active syncs. + now (int): Current time in ms. - @defer.inlineCallbacks - def incoming_presence(self, origin, content): - """Handle an incoming m.presence EDU. - For each presence update in the "push" list update our local cache and - notify the appropriate local clients. Only clients that share a room - or are directly subscribed to the presence for a user should be - notified of the update. - For each subscription request in the "poll" list start pushing presence - updates to the remote server. - For unsubscribe request in the "unpoll" list stop pushing presence - updates to the remote server. + Returns: + List of UserPresenceState updates + """ + changes = {} # Actual changes we need to notify people about - Args: - orgin(str): The source of this m.presence EDU. - content(dict): The content of this m.presence EDU. - Returns: - A Deferred. - """ - deferreds = [] + for state in user_states: + is_mine = is_mine_fn(state.user_id) - for push in content.get("push", []): - user = UserID.from_string(push["user_id"]) + new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + if new_state: + changes[state.user_id] = new_state - logger.debug("Incoming presence update from %s", user) + return changes.values() - observers = set(self._remote_recvmap.get(user, set())) - if observers: - logger.debug( - " | %d interested local observers %r", len(observers), observers - ) - room_ids = yield self.get_joined_rooms_for_user(user) - if room_ids: - logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) +def handle_timeout(state, is_mine, user_to_num_current_syncs, now): + """Checks the presence of the user to see if any of the timers have elapsed - state = dict(push) - del state["user_id"] + Args: + state (UserPresenceState) + is_mine (bool): Whether the user is ours + user_to_num_current_syncs (dict): Mapping of user_id to number of currently + active syncs. + now (int): Current time in ms. - if "presence" not in state: - logger.warning( - "Received a presence 'push' EDU from %s without a" - " 'presence' key", origin + Returns: + A UserPresenceState update or None if no update. + """ + if state.state == PresenceState.OFFLINE: + # No timeouts are associated with offline states. + return None + + changed = False + user_id = state.user_id + + if is_mine: + if state.state == PresenceState.ONLINE: + if now - state.last_active_ts > IDLE_TIMER: + # Currently online, but last activity ages ago so auto + # idle + state = state.copy_and_replace( + state=PresenceState.UNAVAILABLE, ) - continue - - if "last_active_ago" in state: - state["last_active"] = int( - self.clock.time_msec() - state.pop("last_active_ago") + changed = True + elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # So that we send down a notification that we've + # stopped updating. + changed = True + + if now - state.last_federation_update_ts > FEDERATION_PING_INTERVAL: + # Need to send ping to other servers to ensure they don't + # timeout and set us to offline + changed = True + + # If there are have been no sync for a while (and none ongoing), + # set presence to offline + if not user_to_num_current_syncs.get(user_id, 0): + if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: + state = state.copy_and_replace( + state=PresenceState.OFFLINE, + status_msg=None, ) - - self._user_cachemap_latest_serial += 1 - yield self.update_presence_cache(user, state, room_ids=room_ids) - - if not observers and not room_ids: - logger.debug(" | no interested observers or room IDs") - continue - - self.push_update_to_clients( - users_to_push=observers, room_ids=room_ids + changed = True + else: + # We expect to be poked occaisonally by the other side. + # This is to protect against forgetful/buggy servers, so that + # no one gets stuck online forever. + if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: + # The other side seems to have disappeared. + state = state.copy_and_replace( + state=PresenceState.OFFLINE, + status_msg=None, ) + changed = True - user_id = user.to_string() - - if state["presence"] == PresenceState.OFFLINE: - self._remote_offline_serials.insert( - 0, - (self._user_cachemap_latest_serial, set([user_id])) - ) - while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: - self._remote_offline_serials.pop() # remove the oldest - if user in self._user_cachemap: - del self._user_cachemap[user] - else: - # Remove the user from remote_offline_serials now that they're - # no longer offline - for idx, elem in enumerate(self._remote_offline_serials): - (_, user_ids) = elem - user_ids.discard(user_id) - if not user_ids: - self._remote_offline_serials.pop(idx) - - for poll in content.get("poll", []): - user = UserID.from_string(poll) - - if not self.hs.is_mine(user): - continue + return state if changed else None - # TODO(paul) permissions checks - - if user not in self._remote_sendmap: - self._remote_sendmap[user] = set() - - self._remote_sendmap[user].add(origin) - - deferreds.append(self._push_presence_remote(user, origin)) - - for unpoll in content.get("unpoll", []): - user = UserID.from_string(unpoll) - - if not self.hs.is_mine(user): - continue - if user in self._remote_sendmap: - self._remote_sendmap[user].remove(origin) +def handle_update(prev_state, new_state, is_mine, wheel_timer, now): + """Given a presence update: + 1. Add any appropriate timers. + 2. Check if we should notify anyone. - if not self._remote_sendmap[user]: - del self._remote_sendmap[user] + Args: + prev_state (UserPresenceState) + new_state (UserPresenceState) + is_mine (bool): Whether the user is ours + wheel_timer (WheelTimer) + now (int): Time now in ms - yield defer.DeferredList(deferreds, consumeErrors=True) - - @defer.inlineCallbacks - def update_presence_cache(self, user, state={}, room_ids=None, - add_to_cache=True): - """Update the presence cache for a user with a new state and bump the - serial to the latest value. - - Args: - user(UserID): The user being updated - state(dict): The presence state being updated - room_ids(None or list of str): A list of room_ids to update. If - room_ids is None then fetch the list of room_ids the user is - joined to. - add_to_cache: Whether to add an entry to the presence cache if the - user isn't already in the cache. - Returns: - A Deferred UserPresenceCache for the user being updated. - """ - if room_ids is None: - room_ids = yield self.get_joined_rooms_for_user(user) - - for room_id in room_ids: - self._room_serials[room_id] = self._user_cachemap_latest_serial - if add_to_cache: - statuscache = self._get_or_make_usercache(user) - else: - statuscache = self._get_or_offline_usercache(user) - statuscache.update(state, serial=self._user_cachemap_latest_serial) - defer.returnValue(statuscache) - - @defer.inlineCallbacks - def push_update_to_local_and_remote(self, observed_user, statuscache, - users_to_push=[], room_ids=[], - remote_domains=[]): - """Notify local clients and remote servers of a change in the presence - of a user. - - Args: - observed_user(UserID): The user to push the presence state for. - statuscache(UserPresenceCache): The cache for the presence state to - push. - users_to_push([UserID]): A list of local and remote users to - notify. - room_ids([str]): Notify the local and remote occupants of these - rooms. - remote_domains([str]): A list of remote servers to notify in - addition to those implied by the users_to_push and the - room_ids. - Returns: - A Deferred. - """ - - localusers, remoteusers = partitionbool( - users_to_push, - lambda u: self.hs.is_mine(u) - ) - - localusers = set(localusers) - - self.push_update_to_clients( - users_to_push=localusers, room_ids=room_ids - ) - - remote_domains = set(remote_domains) - remote_domains |= set([r.domain for r in remoteusers]) - for room_id in room_ids: - remote_domains.update( - (yield self.store.get_joined_hosts_for_room(room_id)) - ) - - remote_domains.discard(self.hs.hostname) - - deferreds = [] - for domain in remote_domains: - logger.debug(" | push to remote domain %s", domain) - deferreds.append( - self._push_presence_remote( - observed_user, domain, state=statuscache.get_state() - ) + Returns: + 3-tuple: `(new_state, persist_and_notify, federation_ping)` where: + - new_state: is the state to actually persist + - persist_and_notify (bool): whether to persist and notify people + - federation_ping (bool): whether we should send a ping over federation + """ + user_id = new_state.user_id + + persist_and_notify = False + federation_ping = False + + # If the users are ours then we want to set up a bunch of timers + # to time things out. + if is_mine: + if new_state.state == PresenceState.ONLINE: + # Idle timer + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active_ts + IDLE_TIMER ) - yield defer.DeferredList(deferreds, consumeErrors=True) - - defer.returnValue((localusers, remote_domains)) - - def push_update_to_clients(self, users_to_push=[], room_ids=[]): - """Notify clients of a new presence event. - - Args: - users_to_push([UserID]): List of users to notify. - room_ids([str]): List of room_ids to notify. - """ - with PreserveLoggingContext(): - self.notifier.on_new_event( - "presence_key", - self._user_cachemap_latest_serial, - users_to_push, - room_ids, + active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY + new_state = new_state.copy_and_replace( + currently_active=active, ) - @defer.inlineCallbacks - def _push_presence_remote(self, user, destination, state=None): - """Push a user's presence to a remote server. If a presence state event - that event is sent. Otherwise a new state event is constructed from the - stored presence state. - The last_active is replaced with last_active_ago in case the wallclock - time on the remote server is different to the time on this server. - Sends an EDU to the remote server with the current presence state. - - Args: - user(UserID): The user to push the presence state for. - destination(str): The remote server to send state to. - state(dict): The state to push, or None to use the current stored - state. - Returns: - A Deferred. - """ - if state is None: - state = yield self.store.get_presence_state(user.localpart) - del state["mtime"] - state["presence"] = state.pop("state") - - if user in self._user_cachemap: - state["last_active"] = ( - self._user_cachemap[user].get_state()["last_active"] + if active: + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY ) - yield collect_presencelike_data(self.distributor, user, state) - - if "last_active" in state: - state = dict(state) - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") + if new_state.state != PresenceState.OFFLINE: + # User has stopped syncing + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT ) - user_state = {"user_id": user.to_string(), } - user_state.update(state) - - yield self.federation.send_edu( - destination=destination, - edu_type="m.presence", - content={"push": [user_state, ], } - ) - - -class PresenceEventSource(object): - def __init__(self, hs): - self.hs = hs - self.clock = hs.get_clock() - - @defer.inlineCallbacks - @log_function - def get_new_events(self, user, from_key, room_ids=None, **kwargs): - from_key = int(from_key) - room_ids = room_ids or [] - - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - max_serial = presence._user_cachemap_latest_serial - - clock = self.clock - latest_serial = 0 - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] > from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - cached = cachemap[observed_user] - - if cached.serial <= from_key or cached.serial > max_serial: - continue - - latest_serial = max(cached.serial, latest_serial) - updates.append(cached.make_event(user=observed_user, clock=clock)) - - # TODO(paul): limit - - for serial, user_ids in presence._remote_offline_serials: - if serial <= from_key: - break - - if serial > max_serial: - continue - - latest_serial = max(latest_serial, serial) - for u in user_ids: - updates.append({ - "type": "m.presence", - "content": {"user_id": u, "presence": PresenceState.OFFLINE}, - }) - # TODO(paul): For the v2 API we want to tell the client their from_key - # is too old if we fell off the end of the _remote_offline_serials - # list, and get them to invalidate+resync. In v1 we have no such - # concept so this is a best-effort result. - - if updates: - defer.returnValue((updates, latest_serial)) - else: - defer.returnValue(([], presence._user_cachemap_latest_serial)) - - def get_current_key(self): - presence = self.hs.get_handlers().presence_handler - return presence._user_cachemap_latest_serial - - @defer.inlineCallbacks - def get_pagination_rows(self, user, pagination_config, key): - # TODO (erikj): Does this make sense? Ordering? - - from_key = int(pagination_config.from_key) - - if pagination_config.to_key: - to_key = int(pagination_config.to_key) - else: - to_key = -1 - - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap + last_federate = new_state.last_federation_update_ts + if now - last_federate > FEDERATION_PING_INTERVAL: + # Been a while since we've poked remote servers + new_state = new_state.copy_and_replace( + last_federation_update_ts=now, + ) + federation_ping = True - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True + else: + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - room_ids = yield presence.get_joined_rooms_for_user(user) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] >= from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - if not (to_key < cachemap[observed_user].serial <= from_key): - continue - - updates.append((observed_user, cachemap[observed_user])) - - # TODO(paul): limit - - if updates: - clock = self.clock - - earliest_serial = max([x[1].serial for x in updates]) - data = [x[1].make_event(user=x[0], clock=clock) for x in updates] - - defer.returnValue((data, earliest_serial)) - else: - defer.returnValue(([], 0)) - -class UserPresenceCache(object): - """Store an observed user's state and status message. - - Includes the update timestamp. - """ - def __init__(self): - self.state = {"presence": PresenceState.OFFLINE} - self.serial = None - - def __repr__(self): - return "UserPresenceCache(state=%r, serial=%r)" % ( - self.state, self.serial + # Check whether the change was something worth notifying about + if should_notify(prev_state, new_state): + new_state = new_state.copy_and_replace( + last_federation_update_ts=now, ) + persist_and_notify = True - def update(self, state, serial): - assert("mtime_age" not in state) - - self.state.update(state) - # Delete keys that are now 'None' - for k in self.state.keys(): - if self.state[k] is None: - del self.state[k] - - self.serial = serial - - if "status_msg" in state: - self.status_msg = state["status_msg"] - else: - self.status_msg = None - - def get_state(self): - # clone it so caller can't break our cache - state = dict(self.state) - return state - - def make_event(self, user, clock): - content = self.get_state() - content["user_id"] = user.to_string() - - if "last_active" in content: - content["last_active_ago"] = int( - clock.time_msec() - content.pop("last_active") - ) - - return {"type": "m.presence", "content": content} + return new_state, persist_and_notify, federation_ping diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 629e6e3594..c9ad5944e6 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -16,8 +16,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID +from synapse.types import UserID, Requester from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -49,6 +48,9 @@ class ProfileHandler(BaseHandler): distributor = hs.get_distributor() self.distributor = distributor + distributor.declare("collect_presencelike_data") + distributor.declare("changed_presencelike_data") + distributor.observe("registered_user", self.registered_user) distributor.observe( @@ -208,21 +210,18 @@ class ProfileHandler(BaseHandler): ) for j in joins: - content = { - "membership": Membership.JOIN, - } - - yield collect_presencelike_data(self.distributor, user, content) - - msg_handler = self.hs.get_handlers().message_handler + handler = self.hs.get_handlers().room_member_handler try: - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "room_id": j.room_id, - "state_key": user.to_string(), - "content": content, - "sender": user.to_string() - }, ratelimit=False) + # Assume the user isn't a guest because we don't let guests set + # profile or avatar data. + requester = Requester(user, "", False) + yield handler.update_membership( + requester, + user, + j.room_id, + "join", # We treat a profile update like a join. + ratelimit=False, # Try to hide that these events aren't atomic. + ) except Exception as e: logger.warn( "Failed to update join event for room %s - %s", diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index de4c694714..935c339707 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -36,8 +36,6 @@ class ReceiptsHandler(BaseHandler): ) self.clock = self.hs.get_clock() - self._receipt_cache = None - @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, event_id): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 24c850ae9b..6d155d57e7 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -60,7 +60,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) + yield self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: @@ -145,7 +145,7 @@ class RegistrationHandler(BaseHandler): localpart = yield self._generate_user_id(attempts > 0) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) + yield self.check_user_id_not_appservice_exclusive(user_id) if generate_token: token = self.auth_handler().generate_access_token(user_id) try: @@ -180,6 +180,11 @@ class RegistrationHandler(BaseHandler): 400, "Invalid user localpart for this application service.", errcode=Codes.EXCLUSIVE ) + + yield self.check_user_id_not_appservice_exclusive( + user_id, allowed_appservice=service + ) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, @@ -226,7 +231,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) + yield self.check_user_id_not_appservice_exclusive(user_id) token = self.auth_handler().generate_access_token(user_id) try: yield self.store.register( @@ -278,12 +283,14 @@ class RegistrationHandler(BaseHandler): yield identity_handler.bind_threepid(c, user_id) @defer.inlineCallbacks - def check_user_id_is_valid(self, user_id): + def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = yield self.store.get_app_services() interested_services = [ - s for s in services if s.is_interested_in_user(user_id) + s for s in services + if s.is_interested_in_user(user_id) + and s != allowed_appservice ] for service in interested_services: if service.is_exclusive_user(user_id): @@ -342,3 +349,18 @@ class RegistrationHandler(BaseHandler): def auth_handler(self): return self.hs.get_handlers().auth_handler + + @defer.inlineCallbacks + def guest_access_token_for(self, medium, address, inviter_user_id): + access_token = yield self.store.get_3pid_guest_access_token(medium, address) + if access_token: + defer.returnValue(access_token) + + _, access_token = yield self.register( + generate_token=True, + make_guest=True + ) + access_token = yield self.store.save_or_get_3pid_guest_access_token( + medium, address, access_token, inviter_user_id + ) + defer.returnValue(access_token) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b2de2cd0c0..d2de23a6cc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -24,7 +24,6 @@ from synapse.api.constants import ( ) from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.util import stringutils, unwrapFirstError -from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn from signedjson.sign import verify_signed_json @@ -42,10 +41,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - def user_left_room(distributor, user, room_id): return preserve_context_over_fn( distributor.fire, @@ -81,20 +76,20 @@ class RoomCreationHandler(BaseHandler): } @defer.inlineCallbacks - def create_room(self, user_id, room_id, config): + def create_room(self, requester, config): """ Creates a new room. Args: - user_id (str): The ID of the user creating the new room. - room_id (str): The proposed ID for the new room. Can be None, in - which case one will be created for you. + requester (Requester): The user who requested the room creation. config (dict) : A dict of configuration options. Returns: The new room ID. Raises: - SynapseError if the room ID was taken, couldn't be stored, or - something went horribly wrong. + SynapseError if the room ID couldn't be stored, or something went + horribly wrong. """ + user_id = requester.user.to_string() + self.ratelimit(user_id) if "room_alias_name" in config: @@ -126,40 +121,28 @@ class RoomCreationHandler(BaseHandler): is_public = config.get("visibility", None) == "public" - if room_id: - # Ensure room_id is the correct type - room_id_obj = RoomID.from_string(room_id) - if not self.hs.is_mine(room_id_obj): - raise SynapseError(400, "Room id must be local") - - yield self.store.store_room( - room_id=room_id, - room_creator_user_id=user_id, - is_public=is_public - ) - else: - # autogen room IDs and try to create it. We may clash, so just - # try a few times till one goes through, giving up eventually. - attempts = 0 - room_id = None - while attempts < 5: - try: - random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( - random_string, - self.hs.hostname, - ) - yield self.store.store_room( - room_id=gen_room_id.to_string(), - room_creator_user_id=user_id, - is_public=is_public - ) - room_id = gen_room_id.to_string() - break - except StoreError: - attempts += 1 - if not room_id: - raise StoreError(500, "Couldn't generate a room ID.") + # autogen room IDs and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + room_id = None + while attempts < 5: + try: + random_string = stringutils.random_string(18) + gen_room_id = RoomID.create( + random_string, + self.hs.hostname, + ) + yield self.store.store_room( + room_id=gen_room_id.to_string(), + room_creator_user_id=user_id, + is_public=is_public + ) + room_id = gen_room_id.to_string() + break + except StoreError: + attempts += 1 + if not room_id: + raise StoreError(500, "Couldn't generate a room ID.") if room_alias: directory_handler = self.hs.get_handlers().directory_handler @@ -185,9 +168,14 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) - user = UserID.from_string(user_id) - creation_events = self._create_events_for_new_room( - user, room_id, + msg_handler = self.hs.get_handlers().message_handler + room_member_handler = self.hs.get_handlers().room_member_handler + + yield self._send_events_for_new_room( + requester, + room_id, + msg_handler, + room_member_handler, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, @@ -195,14 +183,9 @@ class RoomCreationHandler(BaseHandler): room_alias=room_alias, ) - msg_handler = self.hs.get_handlers().message_handler - - for event in creation_events: - yield msg_handler.create_and_send_event(event, ratelimit=False) - if "name" in config: name = config["name"] - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Name, "room_id": room_id, "sender": user_id, @@ -212,7 +195,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_event({ + yield msg_handler.create_and_send_nonmember_event({ "type": EventTypes.Topic, "room_id": room_id, "sender": user_id, @@ -221,13 +204,13 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) for invitee in invite_list: - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "state_key": invitee, - "room_id": room_id, - "sender": user_id, - "content": {"membership": Membership.INVITE}, - }, ratelimit=False) + room_member_handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + ratelimit=False, + ) for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] @@ -235,11 +218,11 @@ class RoomCreationHandler(BaseHandler): medium = invite_3pid["medium"] yield self.hs.get_handlers().room_member_handler.do_3pid_invite( room_id, - user, + requester.user, medium, address, id_server, - token_id=None, + requester, txn_id=None, ) @@ -253,19 +236,19 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) - def _create_events_for_new_room(self, creator, room_id, preset_config, - invite_list, initial_state, creation_content, - room_alias): - config = RoomCreationHandler.PRESETS_DICT[preset_config] - - creator_id = creator.to_string() - - event_keys = { - "room_id": room_id, - "sender": creator_id, - "state_key": "", - } - + @defer.inlineCallbacks + def _send_events_for_new_room( + self, + creator, # A Requester object. + room_id, + msg_handler, + room_member_handler, + preset_config, + invite_list, + initial_state, + creation_content, + room_alias + ): def create(etype, content, **kwargs): e = { "type": etype, @@ -277,26 +260,39 @@ class RoomCreationHandler(BaseHandler): return e - creation_content.update({"creator": creator.to_string()}) - creation_event = create( + @defer.inlineCallbacks + def send(etype, content, **kwargs): + event = create(etype, content, **kwargs) + yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) + + config = RoomCreationHandler.PRESETS_DICT[preset_config] + + creator_id = creator.user.to_string() + + event_keys = { + "room_id": room_id, + "sender": creator_id, + "state_key": "", + } + + creation_content.update({"creator": creator_id}) + yield send( etype=EventTypes.Create, content=creation_content, ) - join_event = create( - etype=EventTypes.Member, - state_key=creator_id, - content={ - "membership": Membership.JOIN, - }, + yield room_member_handler.update_membership( + creator, + creator.user, + room_id, + "join", + ratelimit=False, ) - returned_events = [creation_event, join_event] - if (EventTypes.PowerLevels, '') not in initial_state: power_level_content = { "users": { - creator.to_string(): 100, + creator_id: 100, }, "users_default": 0, "events": { @@ -318,45 +314,35 @@ class RoomCreationHandler(BaseHandler): for invitee in invite_list: power_level_content["users"][invitee] = 100 - power_levels_event = create( + yield send( etype=EventTypes.PowerLevels, content=power_level_content, ) - returned_events.append(power_levels_event) - if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: - room_alias_event = create( + yield send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) - returned_events.append(room_alias_event) - if (EventTypes.JoinRules, '') not in initial_state: - join_rules_event = create( + yield send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}, ) - returned_events.append(join_rules_event) - if (EventTypes.RoomHistoryVisibility, '') not in initial_state: - history_event = create( + yield send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]} ) - returned_events.append(history_event) - for (etype, state_key), content in initial_state.items(): - returned_events.append(create( + yield send( etype=etype, state_key=state_key, content=content, - )) - - return returned_events + ) class RoomMemberHandler(BaseHandler): @@ -404,16 +390,35 @@ class RoomMemberHandler(BaseHandler): remotedomains.add(member.domain) @defer.inlineCallbacks - def update_membership(self, requester, target, room_id, action, txn_id=None): + def update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + ): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" elif action == "forget": effective_membership_state = "leave" + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + msg_handler = self.hs.get_handlers().message_handler - content = {"membership": unicode(effective_membership_state)} + content = {"membership": effective_membership_state} if requester.is_guest: content["kind"] = "guest" @@ -424,6 +429,9 @@ class RoomMemberHandler(BaseHandler): "room_id": room_id, "sender": requester.user.to_string(), "state_key": target.to_string(), + + # For backwards compatibility: + "membership": effective_membership_state, }, token_id=requester.access_token_id, txn_id=txn_id, @@ -444,202 +452,181 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) - yield msg_handler.send_event( + member_handler = self.hs.get_handlers().room_member_handler + yield member_handler.send_membership_event( event, context, - ratelimit=True, - is_guest=requester.is_guest + is_guest=requester.is_guest, + ratelimit=ratelimit, + remote_room_hosts=remote_room_hosts, + from_client=True, ) if action == "forget": yield self.forget(requester.user, room_id) @defer.inlineCallbacks - def send_membership_event(self, event, context, is_guest=False): - """ Change the membership status of a user in a room. + def send_membership_event( + self, + event, + context, + is_guest=False, + remote_room_hosts=None, + ratelimit=True, + from_client=True, + ): + """ + Change the membership status of a user in a room. Args: - event (SynapseEvent): The membership event + event (SynapseEvent): The membership event. + context: The context of the event. + is_guest (bool): Whether the sender is a guest. + room_hosts ([str]): Homeservers which are likely to already be in + the room, and could be danced with in order to join this + homeserver for the first time. + ratelimit (bool): Whether to rate limit this request. + from_client (bool): Whether this request is the result of a local + client request (rather than over federation). If so, we will + perform extra checks, like that this homeserver can act as this + client. Raises: SynapseError if there was a problem changing the membership. """ - target_user_id = event.state_key + target_user = UserID.from_string(event.state_key) + room_id = event.room_id - prev_state = context.current_state.get( - (EventTypes.Member, target_user_id), - None - ) + if from_client: + sender = UserID.from_string(event.sender) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) - room_id = event.room_id + message_handler = self.hs.get_handlers().message_handler + prev_event = message_handler.deduplicate_state_event(event, context) + if prev_event is not None: + return - # If we're trying to join a room then we have to do this differently - # if this HS is not currently in the room, i.e. we have to do the - # invite/join dance. - if event.membership == Membership.JOIN: - if is_guest: - guest_access = context.current_state.get( - (EventTypes.GuestAccess, ""), - None - ) - is_guest_access_allowed = ( - guest_access - and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" - ) - if not is_guest_access_allowed: - raise AuthError(403, "Guest access not allowed") + action = "send" - yield self._do_join(event, context) - else: - if event.membership == Membership.LEAVE: - is_host_in_room = yield self.is_host_in_room(room_id, context) - if not is_host_in_room: - # Rejecting an invite, rather than leaving a joined room - handler = self.hs.get_handlers().federation_handler - inviter = yield self.get_inviter(event) - if not inviter: - # return the same error as join_room_alias does - raise SynapseError(404, "No known servers") - yield handler.do_remotely_reject_invite( - [inviter.domain], - room_id, - event.user_id - ) - defer.returnValue({"room_id": room_id}) - return - - # FIXME: This isn't idempotency. - if prev_state and prev_state.membership == event.membership: - # double same action, treat this event as a NOOP. - defer.returnValue({}) - return - - yield self._do_local_membership_update( - event, - context=context, + if event.membership == Membership.JOIN: + if is_guest and not self._can_guest_join(context.current_state): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + do_remote_join_dance, remote_room_hosts = self._should_do_dance( + context, + (self.get_inviter(event.state_key, context.current_state)), + remote_room_hosts, ) + if do_remote_join_dance: + action = "remote_join" + elif event.membership == Membership.LEAVE: + is_host_in_room = self.is_host_in_room(context.current_state) + if not is_host_in_room: + action = "remote_reject" - if prev_state and prev_state.membership == Membership.JOIN: - user = UserID.from_string(event.user_id) - user_left_room(self.distributor, user, event.room_id) + federation_handler = self.hs.get_handlers().federation_handler - defer.returnValue({"room_id": room_id}) + if action == "remote_join": + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") - @defer.inlineCallbacks - def join_room_alias(self, joinee, room_alias, content={}): - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield federation_handler.do_invite_join( + remote_room_hosts, + event.room_id, + event.user_id, + event.content, + ) + elif action == "remote_reject": + inviter = self.get_inviter(target_user.to_string(), context.current_state) + if not inviter: + raise SynapseError(404, "No known servers") + yield federation_handler.do_remotely_reject_invite( + [inviter.domain], + room_id, + event.user_id + ) + else: + yield self.handle_new_client_event( + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) - if not mapping: - raise SynapseError(404, "No such room alias") + prev_member_event = context.current_state.get( + (EventTypes.Member, target_user.to_string()), + None + ) - room_id = mapping["room_id"] - hosts = mapping["servers"] - if not hosts: - raise SynapseError(404, "No known servers") + if event.membership == Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + yield user_joined_room(self.distributor, target_user, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event and prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) + + def _can_guest_join(self, current_state): + """ + Returns whether a guest can join a room based on its current state. + """ + guest_access = current_state.get((EventTypes.GuestAccess, ""), None) + return ( + guest_access + and guest_access.content + and "guest_access" in guest_access.content + and guest_access.content["guest_access"] == "can_join" + ) - # If event doesn't include a display name, add one. - yield collect_presencelike_data(self.distributor, joinee, content) + def _should_do_dance(self, context, inviter, room_hosts=None): + # TODO: Shouldn't this be remote_room_host? + room_hosts = room_hosts or [] - content.update({"membership": Membership.JOIN}) - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "state_key": joinee.to_string(), - "room_id": room_id, - "sender": joinee.to_string(), - "membership": Membership.JOIN, - "content": content, - }) - event, context = yield self._create_new_client_event(builder) + is_host_in_room = self.is_host_in_room(context.current_state) + if is_host_in_room: + return False, room_hosts - yield self._do_join(event, context, room_hosts=hosts) + if inviter and not self.hs.is_mine(inviter): + room_hosts.append(inviter.domain) - defer.returnValue({"room_id": room_id}) + return True, room_hosts @defer.inlineCallbacks - def _do_join(self, event, context, room_hosts=None): - room_id = event.room_id - - # XXX: We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - - is_host_in_room = yield self.is_host_in_room(room_id, context) - if is_host_in_room: - should_do_dance = False - elif room_hosts: # TODO: Shouldn't this be remote_room_host? - should_do_dance = True - else: - inviter = yield self.get_inviter(event) - if not inviter: - # return the same error as join_room_alias does - raise SynapseError(404, "No known servers") - should_do_dance = not self.hs.is_mine(inviter) - room_hosts = [inviter.domain] + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. - if should_do_dance: - handler = self.hs.get_handlers().federation_handler - yield handler.do_invite_join( - room_hosts, - room_id, - event.user_id, - event.content, - ) - else: - logger.debug("Doing normal join") + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). + Raises: + SynapseError if room alias could not be found. + """ + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) - yield self._do_local_membership_update( - event, - context=context, - ) + if not mapping: + raise SynapseError(404, "No such room alias") - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. - user = UserID.from_string(event.user_id) - yield user_joined_room(self.distributor, user, room_id) + room_id = mapping["room_id"] + servers = mapping["servers"] - @defer.inlineCallbacks - def get_inviter(self, event): - # TODO(markjh): get prev_state from snapshot - prev_state = yield self.store.get_room_member( - event.user_id, event.room_id - ) + defer.returnValue((RoomID.from_string(room_id), servers)) + def get_inviter(self, user_id, current_state): + prev_state = current_state.get((EventTypes.Member, user_id)) if prev_state and prev_state.membership == Membership.INVITE: - defer.returnValue(UserID.from_string(prev_state.user_id)) - return - elif "third_party_invite" in event.content: - if "sender" in event.content["third_party_invite"]: - inviter = UserID.from_string( - event.content["third_party_invite"]["sender"] - ) - defer.returnValue(inviter) - defer.returnValue(None) - - @defer.inlineCallbacks - def is_host_in_room(self, room_id, context): - is_host_in_room = yield self.auth.check_host_in_room( - room_id, - self.hs.hostname - ) - if not is_host_in_room: - # is *anyone* in the room? - room_member_keys = [ - v for (k, v) in context.current_state.keys() if ( - k == "m.room.member" - ) - ] - if len(room_member_keys) == 0: - # has the room been created so we can join it? - create_event = context.current_state.get(("m.room.create", "")) - if create_event: - is_host_in_room = True - defer.returnValue(is_host_in_room) + return UserID.from_string(prev_state.user_id) + return None @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): @@ -657,18 +644,6 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(room_ids) @defer.inlineCallbacks - def _do_local_membership_update(self, event, context): - yield run_on_reactor() - - target_user = UserID.from_string(event.state_key) - - yield self.handle_new_client_event( - event, - context, - extra_users=[target_user], - ) - - @defer.inlineCallbacks def do_3pid_invite( self, room_id, @@ -676,7 +651,7 @@ class RoomMemberHandler(BaseHandler): medium, address, id_server, - token_id, + requester, txn_id ): invitee = yield self._lookup_3pid( @@ -684,19 +659,12 @@ class RoomMemberHandler(BaseHandler): ) if invitee: - # make sure it looks like a user ID; it'll throw if it's invalid. - UserID.from_string(invitee) - yield self.hs.get_handlers().message_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": { - "membership": unicode("invite") - }, - "room_id": room_id, - "sender": inviter.to_string(), - "state_key": invitee, - }, - token_id=token_id, + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", txn_id=txn_id, ) else: @@ -706,7 +674,7 @@ class RoomMemberHandler(BaseHandler): address, room_id, inviter, - token_id, + requester.access_token_id, txn_id=txn_id ) @@ -801,7 +769,7 @@ class RoomMemberHandler(BaseHandler): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_key, key_validity_url, display_name = ( + token, public_keys, fallback_public_key, display_name = ( yield self._ask_id_server_for_third_party_invite( id_server=id_server, medium=medium, @@ -816,14 +784,18 @@ class RoomMemberHandler(BaseHandler): inviter_avatar_url=inviter_avatar_url ) ) + msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event( + yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.ThirdPartyInvite, "content": { "display_name": display_name, - "key_validity_url": key_validity_url, - "public_key": public_key, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], }, "room_id": room_id, "sender": user.to_string(), @@ -848,6 +820,41 @@ class RoomMemberHandler(BaseHandler): inviter_display_name, inviter_avatar_url ): + """ + Asks an identity server for a third party invite. + + :param id_server (str): hostname + optional port for the identity server. + :param medium (str): The literal string "email". + :param address (str): The third party address being invited. + :param room_id (str): The ID of the room to which the user is invited. + :param inviter_user_id (str): The user ID of the inviter. + :param room_alias (str): An alias for the room, for cosmetic + notifications. + :param room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + :param room_join_rules (str): The join rules of the email + (e.g. "public"). + :param room_name (str): The m.room.name of the room. + :param inviter_display_name (str): The current display name of the + inviter. + :param inviter_avatar_url (str): The URL of the inviter's avatar. + + :return: A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( id_server_scheme, id_server, ) @@ -864,16 +871,26 @@ class RoomMemberHandler(BaseHandler): "sender": inviter_user_id, "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, + "guest_access_token": guest_access_token, } ) # TODO: Check for success token = data["token"] - public_key = data["public_key"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) display_name = data["display_name"] - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ) - defer.returnValue((token, public_key, key_validity_url, display_name)) + defer.returnValue((token, public_keys, fallback_public_key, display_name)) def forget(self, user, room_id): return self.store.forget(user.to_string(), room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1d0f0058a2..fded6e4009 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -121,7 +121,11 @@ class SyncResult(collections.namedtuple("SyncResult", [ events. """ return bool( - self.presence or self.joined or self.invited or self.archived + self.presence or + self.joined or + self.invited or + self.archived or + self.account_data ) @@ -582,6 +586,28 @@ class SyncHandler(BaseHandler): if room_sync: joined.append(room_sync) + # For each newly joined room, we want to send down presence of + # existing users. + presence_handler = self.hs.get_handlers().presence_handler + extra_presence_users = set() + for room_id in newly_joined_rooms: + users = yield self.store.get_users_in_room(event.room_id) + extra_presence_users.update(users) + + # For each new member, send down presence. + for joined_sync in joined: + it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + extra_presence_users.add(event.state_key) + + states = yield presence_handler.get_states( + [u for u in extra_presence_users if u != user_id], + as_event=True, + ) + presence.extend(states) + account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) ) @@ -623,7 +649,6 @@ class SyncHandler(BaseHandler): recents = yield self._filter_events_for_client( sync_config.user.to_string(), recents, - is_peeking=sync_config.is_guest, ) else: recents = [] @@ -645,7 +670,6 @@ class SyncHandler(BaseHandler): loaded_recents = yield self._filter_events_for_client( sync_config.user.to_string(), loaded_recents, - is_peeking=sync_config.is_guest, ) loaded_recents.extend(recents) recents = loaded_recents @@ -825,14 +849,20 @@ class SyncHandler(BaseHandler): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: + current_state = yield self.store.get_state_for_event( + batch.events[-1].event_id + ) + state = yield self.store.get_state_for_event( batch.events[0].event_id ) else: - state = yield self.get_state_at( + current_state = yield self.get_state_at( room_id, stream_position=now_token ) + state = current_state + timeline_state = { (event.type, event.state_key): event for event in batch.events if event.is_state() @@ -842,12 +872,17 @@ class SyncHandler(BaseHandler): timeline_contains=timeline_state, timeline_start=state, previous={}, + current=current_state, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) + current_state = yield self.store.get_state_for_event( + batch.events[-1].event_id + ) + state_at_timeline_start = yield self.store.get_state_for_event( batch.events[0].event_id ) @@ -861,6 +896,7 @@ class SyncHandler(BaseHandler): timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, + current=current_state, ) else: state = {} @@ -920,7 +956,7 @@ def _action_has_highlight(actions): return False -def _calculate_state(timeline_contains, timeline_start, previous): +def _calculate_state(timeline_contains, timeline_start, previous, current): """Works out what state to include in a sync response. Args: @@ -928,6 +964,7 @@ def _calculate_state(timeline_contains, timeline_start, previous): timeline_start (dict): state at the start of the timeline previous (dict): state at the end of the previous sync (or empty dict if this is an initial sync) + current (dict): state at the end of the timeline Returns: dict @@ -938,14 +975,16 @@ def _calculate_state(timeline_contains, timeline_start, previous): timeline_contains.values(), previous.values(), timeline_start.values(), + current.values(), ) } + c_ids = set(e.event_id for e in current.values()) tc_ids = set(e.event_id for e in timeline_contains.values()) p_ids = set(e.event_id for e in previous.values()) ts_ids = set(e.event_id for e in timeline_start.values()) - state_ids = (ts_ids - p_ids) - tc_ids + state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids evs = (event_id_to_state[e] for e in state_ids) return { diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b16d0017df..8ce27f49ec 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -25,6 +25,7 @@ from synapse.types import UserID import logging from collections import namedtuple +import ujson as json logger = logging.getLogger(__name__) @@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler): "typing_key", self._latest_room_serial, rooms=[room_id] ) + def get_all_typing_updates(self, last_id, current_id): + # TODO: Work out a way to do this without scanning the entire state. + rows = [] + for room_id, serial in self._room_serials.items(): + if last_id < serial and serial <= current_id: + typing = self._room_typing[room_id] + typing_bytes = json.dumps([ + u.to_string() for u in typing + ], ensure_ascii=False) + rows.append((serial, room_id, typing_bytes)) + rows.sort() + return rows + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/http/server.py b/synapse/http/server.py index a90e2e1125..b17b190ee5 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -367,10 +367,29 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, "Origin, X-Requested-With, Content-Type, Accept") request.write(json_bytes) - request.finish() + finish_request(request) return NOT_DONE_YET +def finish_request(request): + """ Finish writing the response to the request. + + Twisted throws a RuntimeException if the connection closed before the + response was written but doesn't provide a convenient or reliable way to + determine if the connection was closed. So we catch and log the RuntimeException + + You might think that ``request.notifyFinish`` could be used to tell if the + request was finished. However the deferred it returns won't fire if the + connection was already closed, meaning we'd have to have called the method + right at the start of the request. By the time we want to write the response + it will already be too late. + """ + try: + request.finish() + except RuntimeError as e: + logger.info("Connection disconnected before response was written: %r", e) + + def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( "User-Agent", default=[] diff --git a/synapse/notifier.py b/synapse/notifier.py index 560866b26e..3c36a20868 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,8 @@ class Notifier(object): self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS ) + self.replication_deferred = ObservableDeferred(defer.Deferred()) + # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. @@ -207,6 +209,8 @@ class Notifier(object): )) self._notify_pending_new_room_events(max_room_stream_id) + self.notify_replication() + def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -276,6 +280,8 @@ class Notifier(object): except: logger.exception("Failed to notify listener") + self.notify_replication() + @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): @@ -479,3 +485,45 @@ class Notifier(object): room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) + + def notify_replication(self): + """Notify the any replication listeners that there's a new event""" + with PreserveLoggingContext(): + deferred = self.replication_deferred + self.replication_deferred = ObservableDeferred(defer.Deferred()) + deferred.callback(None) + + @defer.inlineCallbacks + def wait_for_replication(self, callback, timeout): + """Wait for an event to happen. + + :param callback: + Gets called whenever an event happens. If this returns a truthy + value then ``wait_for_replication`` returns, otherwise it waits + for another event. + :param int timeout: + How many milliseconds to wait for callback return a truthy value. + :returns: + A deferred that resolves with the value returned by the callback. + """ + listener = _NotificationListener(None) + + def timed_out(): + listener.deferred.cancel() + + timer = self.clock.call_later(timeout / 1000., timed_out) + while True: + listener.deferred = self.replication_deferred.observe() + result = yield callback() + if result: + break + + try: + with PreserveLoggingContext(): + yield listener.deferred + except defer.CancelledError: + break + + self.clock.cancel_call_later(timer, ignore_errs=True) + + defer.returnValue(result) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 8da2d8716c..4c6c3b83a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -47,14 +47,13 @@ class Pusher(object): MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): self.hs = _hs self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.profile_tag = profile_tag self.user_id = user_id self.app_id = app_id self.app_display_name = app_display_name @@ -186,8 +185,8 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id_and_profile_tag( - self.user_id, self.profile_tag, single_event['room_id'], self.store + push_rule_evaluator.evaluator_for_user_id( + self.user_id, single_event['room_id'], self.store ) actions = yield rule_evaluator.actions_for_event(single_event) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index e0da0868ec..c6c1dc769e 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -44,5 +44,5 @@ class ActionGenerator: ) context.push_actions = [ - (uid, None, actions) for uid, actions in actions_by_user.items() + (uid, actions) for uid, actions in actions_by_user.items() ] diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 0832c77cb4..86a2998bcc 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -13,46 +13,67 @@ # limitations under the License. from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +import copy def list_with_base_rules(rawrules): + """Combine the list of rules set by the user with the default push rules + + :param list rawrules: The rules the user has modified or set. + :returns: A new list with the rules set by the user combined with the + defaults. + """ ruleslist = [] + # Grab the base rules that the user has modified. + # The modified base rules have a priority_class of -1. + modified_base_rules = { + r['rule_id']: r for r in rawrules if r['priority_class'] < 0 + } + + # Remove the modified base rules from the list, They'll be added back + # in the default postions in the list. + rawrules = [r for r in rawrules if r['priority_class'] >= 0] + # shove the server default rules for each kind onto the end of each current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules )) for r in rawrules: if r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) ruleslist.append(r) while current_prio_class > 0: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) return ruleslist -def make_base_append_rules(kind): +def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': @@ -62,15 +83,31 @@ def make_base_append_rules(kind): elif kind == 'content': rules = BASE_APPEND_CONTENT_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules -def make_base_prepend_rules(kind): +def make_base_prepend_rules(kind, modified_base_rules): rules = [] if kind == 'override': rules = BASE_PREPEND_OVERRIDE_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules @@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +BASE_RULE_IDS = set() + for r in BASE_APPEND_CONTENT_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_PREPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_OVRRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_UNDERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8ac5ceb9ef..5d8be483e5 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -103,7 +103,7 @@ class BulkPushRuleEvaluator: users_dict = yield self.store.are_guests(self.rules_by_user.keys()) - filtered_by_user = yield handler._filter_events_for_clients( + filtered_by_user = yield handler.filter_events_for_clients( users_dict.items(), [event], {event.event_id: current_state} ) @@ -152,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): elif res is True: continue - res = evaluator.matches(cond, uid, display_name, None) + res = evaluator.matches(cond, uid, display_name) if _id: cache[_id] = bool(res) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index cdc4494928..9be4869360 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -23,12 +23,11 @@ logger = logging.getLogger(__name__) class HttpPusher(Pusher): - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): super(HttpPusher, self).__init__( _hs, - profile_tag, user_id, app_id, app_display_name, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 2a2b4437dc..98e2a2015e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @defer.inlineCallbacks -def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): +def evaluator_for_user_id(user_id, room_id, store): rawrules = yield store.get_push_rules_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id) our_member_event = yield store.get_current_state( @@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): ) defer.returnValue(PushRuleEvaluator( - user_id, profile_tag, rawrules, enabled_map, + user_id, rawrules, enabled_map, room_id, our_member_event, store )) @@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count): class PushRuleEvaluator: DEFAULT_ACTIONS = [] - def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, + def __init__(self, user_id, raw_rules, enabled_map, room_id, our_member_event, store): self.user_id = user_id - self.profile_tag = profile_tag self.room_id = room_id self.our_member_event = our_member_event self.store = store @@ -152,7 +151,7 @@ class PushRuleEvaluator: matches = True for c in conditions: matches = evaluator.matches( - c, self.user_id, my_display_name, self.profile_tag + c, self.user_id, my_display_name ) if not matches: break @@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name, profile_tag): + def matches(self, condition, user_id, display_name): if condition['kind'] == 'event_match': return self._event_match(condition, user_id) - elif condition['kind'] == 'device': - if 'profile_tag' not in condition: - return True - return condition['profile_tag'] == profile_tag elif condition['kind'] == 'contains_display_name': return self._contains_display_name(display_name) elif condition['kind'] == 'room_member_count': diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d7dcb2de4b..a05aa5f661 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -29,6 +29,7 @@ class PusherPool: def __init__(self, _hs): self.hs = _hs self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() self.pushers = {} self.last_pusher_started = -1 @@ -38,8 +39,11 @@ class PusherPool: self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data): + def add_pusher(self, user_id, access_token, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data, + profile_tag=""): + time_now_msec = self.clock.time_msec() + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one @@ -47,23 +51,31 @@ class PusherPool: self._create_pusher({ "user_name": user_id, "kind": kind, - "profile_tag": profile_tag, "app_id": app_id, "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "ts": self.hs.get_clock().time_msec(), + "ts": time_now_msec, "lang": lang, "data": data, "last_token": None, "last_success": None, "failing_since": None }) - yield self._add_pusher_to_store( - user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, - pushkey, lang, data + yield self.store.add_pusher( + user_id=user_id, + access_token=access_token, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=time_now_msec, + lang=lang, + data=data, + profile_tag=profile_tag, ) + yield self._refresh_pusher(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -94,30 +106,10 @@ class PusherPool: ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - @defer.inlineCallbacks - def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, lang, data): - yield self.store.add_pusher( - user_id=user_id, - access_token=access_token, - profile_tag=profile_tag, - kind=kind, - app_id=app_id, - app_display_name=app_display_name, - device_display_name=device_display_name, - pushkey=pushkey, - pushkey_ts=self.hs.get_clock().time_msec(), - lang=lang, - data=data, - ) - yield self._refresh_pusher(app_id, pushkey, user_id) - def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': return HttpPusher( self.hs, - profile_tag=pusherdict['profile_tag'], user_id=pusherdict['user_name'], app_id=pusherdict['app_id'], app_display_name=pusherdict['app_display_name'], diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 75bf3d13aa..35933324a4 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "frozendict>=0.4": ["frozendict"], - "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], + "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], diff --git a/synapse/replication/__init__.py b/synapse/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py new file mode 100644 index 0000000000..e0d039518d --- /dev/null +++ b/synapse/replication/resource.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.servlet import parse_integer, parse_string +from synapse.http.server import request_handler, finish_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + +import ujson as json + +import collections +import logging + +logger = logging.getLogger(__name__) + +REPLICATION_PREFIX = "/_synapse/replication" + +STREAM_NAMES = ( + ("events",), + ("presence",), + ("typing",), + ("receipts",), + ("user_account_data", "room_account_data", "tag_account_data",), + ("backfill",), +) + + +class ReplicationResource(Resource): + """ + HTTP endpoint for extracting data from synapse. + + The streams of data returned by the endpoint are controlled by the + parameters given to the API. To return a given stream pass a query + parameter with a position in the stream to return data from or the + special value "-1" to return data from the start of the stream. + + If there is no data for any of the supplied streams after the given + position then the request will block until there is data for one + of the streams. This allows clients to long-poll this API. + + The possible streams are: + + * "streams": A special stream returing the positions of other streams. + * "events": The new events seen on the server. + * "presence": Presence updates. + * "typing": Typing updates. + * "receipts": Receipt updates. + * "user_account_data": Top-level per user account data. + * "room_account_data: Per room per user account data. + * "tag_account_data": Per room per user tags. + * "backfill": Old events that have been backfilled from other servers. + + The API takes two additional query parameters: + + * "timeout": How long to wait before returning an empty response. + * "limit": The maximum number of rows to return for the selected streams. + + The response is a JSON object with keys for each stream with updates. Under + each key is a JSON object with: + + * "postion": The current position of the stream. + * "field_names": The names of the fields in each row. + * "rows": The updates as an array of arrays. + + There are a number of ways this API could be used: + + 1) To replicate the contents of the backing database to another database. + 2) To be notified when the contents of a shared backing database changes. + 3) To "tail" the activity happening on a server for debugging. + + In the first case the client would track all of the streams and store it's + own copy of the data. + + In the second case the client might theoretically just be able to follow + the "streams" stream to track where the other streams are. However in + practise it will probably need to get the contents of the streams in + order to expire the any in-memory caches. Whether it gets the contents + of the streams from this replication API or directly from the backing + store is a matter of taste. + + In the third case the client would use the "streams" stream to find what + streams are available and their current positions. Then it can start + long-polling this replication API for new data on those streams. + """ + + isLeaf = True + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.store = hs.get_datastore() + self.sources = hs.get_event_sources() + self.presence_handler = hs.get_handlers().presence_handler + self.typing_handler = hs.get_handlers().typing_notification_handler + self.notifier = hs.notifier + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @defer.inlineCallbacks + def current_replication_token(self): + stream_token = yield self.sources.get_current_token() + backfill_token = yield self.store.get_current_backfill_token() + + defer.returnValue(_ReplicationToken( + stream_token.room_stream_id, + int(stream_token.presence_key), + int(stream_token.typing_key), + int(stream_token.receipt_key), + int(stream_token.account_data_key), + backfill_token, + )) + + @request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + limit = parse_integer(request, "limit", 100) + timeout = parse_integer(request, "timeout", 10 * 1000) + + request.setHeader(b"Content-Type", b"application/json") + writer = _Writer(request) + + @defer.inlineCallbacks + def replicate(): + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) + + yield self.account_data(writer, current_token, limit) + yield self.events(writer, current_token, limit) + yield self.presence(writer, current_token) # TODO: implement limit + yield self.typing(writer, current_token) # TODO: implement limit + yield self.receipts(writer, current_token, limit) + self.streams(writer, current_token) + + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.total) + + yield self.notifier.wait_for_replication(replicate, timeout) + + writer.finish() + + def streams(self, writer, current_token): + request_token = parse_string(writer.request, "streams") + + streams = [] + + if request_token is not None: + if request_token == "-1": + for names, position in zip(STREAM_NAMES, current_token): + streams.extend((name, position) for name in names) + else: + items = zip( + STREAM_NAMES, + current_token, + _ReplicationToken(request_token) + ) + for names, current_id, last_id in items: + if last_id < current_id: + streams.extend((name, current_id) for name in names) + + if streams: + writer.write_header_and_rows( + "streams", streams, ("name", "position"), + position=str(current_token) + ) + + @defer.inlineCallbacks + def events(self, writer, current_token, limit): + request_events = parse_integer(writer.request, "events") + request_backfill = parse_integer(writer.request, "backfill") + + if request_events is not None or request_backfill is not None: + if request_events is None: + request_events = current_token.events + if request_backfill is None: + request_backfill = current_token.backfill + events_rows, backfill_rows = yield self.store.get_all_new_events( + request_backfill, request_events, + current_token.backfill, current_token.events, + limit + ) + writer.write_header_and_rows( + "events", events_rows, ("position", "internal", "json") + ) + writer.write_header_and_rows( + "backfill", backfill_rows, ("position", "internal", "json") + ) + + @defer.inlineCallbacks + def presence(self, writer, current_token): + current_position = current_token.presence + + request_presence = parse_integer(writer.request, "presence") + + if request_presence is not None: + presence_rows = yield self.presence_handler.get_all_presence_updates( + request_presence, current_position + ) + writer.write_header_and_rows("presence", presence_rows, ( + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + )) + + @defer.inlineCallbacks + def typing(self, writer, current_token): + current_position = current_token.presence + + request_typing = parse_integer(writer.request, "typing") + + if request_typing is not None: + typing_rows = yield self.typing_handler.get_all_typing_updates( + request_typing, current_position + ) + writer.write_header_and_rows("typing", typing_rows, ( + "position", "room_id", "typing" + )) + + @defer.inlineCallbacks + def receipts(self, writer, current_token, limit): + current_position = current_token.receipts + + request_receipts = parse_integer(writer.request, "receipts") + + if request_receipts is not None: + receipts_rows = yield self.store.get_all_updated_receipts( + request_receipts, current_position, limit + ) + writer.write_header_and_rows("receipts", receipts_rows, ( + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + )) + + @defer.inlineCallbacks + def account_data(self, writer, current_token, limit): + current_position = current_token.account_data + + user_account_data = parse_integer(writer.request, "user_account_data") + room_account_data = parse_integer(writer.request, "room_account_data") + tag_account_data = parse_integer(writer.request, "tag_account_data") + + if user_account_data is not None or room_account_data is not None: + if user_account_data is None: + user_account_data = current_position + if room_account_data is None: + room_account_data = current_position + user_rows, room_rows = yield self.store.get_all_updated_account_data( + user_account_data, room_account_data, current_position, limit + ) + writer.write_header_and_rows("user_account_data", user_rows, ( + "position", "user_id", "type", "content" + )) + writer.write_header_and_rows("room_account_data", room_rows, ( + "position", "user_id", "room_id", "type", "content" + )) + + if tag_account_data is not None: + tag_rows = yield self.store.get_all_updated_tags( + tag_account_data, current_position, limit + ) + writer.write_header_and_rows("tag_account_data", tag_rows, ( + "position", "user_id", "room_id", "tags" + )) + + +class _Writer(object): + """Writes the streams as a JSON object as the response to the request""" + def __init__(self, request): + self.streams = {} + self.request = request + self.total = 0 + + def write_header_and_rows(self, name, rows, fields, position=None): + if not rows: + return + + if position is None: + position = rows[-1][0] + + self.streams[name] = { + "position": str(position), + "field_names": fields, + "rows": rows, + } + + self.total += len(rows) + + def finish(self): + self.request.write(json.dumps(self.streams, ensure_ascii=False)) + finish_request(self.request) + + +class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( + "events", "presence", "typing", "receipts", "account_data", "backfill", +))): + __slots__ = [] + + def __new__(cls, *args): + if len(args) == 1: + return cls(*(int(value) for value in args[0].split("_"))) + else: + return super(_ReplicationToken, cls).__new__(cls, *args) + + def __str__(self): + return "_".join(str(value) for value in self) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 7199113dac..f13272da8e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -17,6 +17,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID +from synapse.http.server import finish_request + from base import ClientV1RestServlet, client_path_patterns import simplejson as json @@ -263,7 +265,7 @@ class SAML2RestServlet(ClientV1RestServlet): '?status=authenticated&access_token=' + token + '&user_id=' + user_id + '&ava=' + urllib.quote(json.dumps(saml2_auth.ava))) - request.finish() + finish_request(request) defer.returnValue(None) defer.returnValue((200, {"status": "authenticated", "user_id": user_id, "token": token, @@ -272,7 +274,7 @@ class SAML2RestServlet(ClientV1RestServlet): request.redirect(urllib.unquote( request.args['RelayState'][0]) + '?status=not_authenticated') - request.finish() + finish_request(request) defer.returnValue(None) defer.returnValue((200, {"status": "not_authenticated"})) @@ -309,7 +311,7 @@ class CasRedirectServlet(ClientV1RestServlet): "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) }) request.redirect("%s?%s" % (self.cas_server_url, service_param)) - request.finish() + finish_request(request) class CasTicketServlet(ClientV1RestServlet): @@ -362,7 +364,7 @@ class CasTicketServlet(ClientV1RestServlet): redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) - request.finish() + finish_request(request) def add_login_token_to_redirect_url(self, url, token): url_parts = list(urlparse.urlparse(url)) @@ -402,10 +404,12 @@ def _parse_json(request): try: content = json.loads(request.content.read()) if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") + raise SynapseError( + 400, "Content must be a JSON object.", errcode=Codes.BAD_JSON + ) return content except ValueError: - raise SynapseError(400, "Content not JSON.") + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a6f8754e32..bbfa1d6ac4 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,7 +17,7 @@ """ from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError from synapse.types import UserID from .base import ClientV1RestServlet, client_path_patterns @@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - state = yield self.handlers.presence_handler.get_state( - target_user=user, auth_user=requester.user) + if requester.user != user: + allowed = yield self.handlers.presence_handler.is_visible( + observed_user=user, observer_user=requester.user, + ) + + if not allowed: + raise AuthError(403, "You are not allowed to see their presence.") + + state = yield self.handlers.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + if requester.user != user: + raise AuthError(403, "Can only set your own presence state") + state = {} try: content = json.loads(request.content.read()) @@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state( - target_user=user, auth_user=requester.user, state=state) + yield self.handlers.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet): raise SynapseError(400, "Cannot get another user's presence list") presence = yield self.handlers.presence_handler.get_presence_list( - observer_user=user, accepted=True) - - for p in presence: - observed_user = p.pop("observed_user") - p["user_id"] = observed_user.to_string() + observer_user=user, accepted=True + ) defer.returnValue((200, presence)) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 96633a176c..970a019223 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -import synapse.push.baserules as baserules +from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) @@ -55,12 +55,15 @@ class PushRuleRestServlet(ClientV1RestServlet): yield self.set_rule_attr(requester.user.to_string(), spec, content) defer.returnValue((200, {})) + if spec['rule_id'].startswith('.'): + # Rule ids starting with '.' are reserved for server default rules. + raise SynapseError(400, "cannot add new rule_ids that start with '.'") + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], spec['rule_id'], content, - device=spec['device'] if 'device' in spec else None ) except InvalidRuleException as e: raise SynapseError(400, e.message) @@ -129,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist)) + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) rules = {'global': {}, 'device': {}} @@ -153,23 +156,7 @@ class PushRuleRestServlet(ClientV1RestServlet): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - if r['priority_class'] > PRIORITY_CLASS_MAP['override']: - # per-device rule - profile_tag = _profile_tag_from_conditions(r["conditions"]) - r = _strip_device_condition(r) - if not profile_tag: - continue - if profile_tag not in rules['device']: - rules['device'][profile_tag] = {} - rules['device'][profile_tag] = ( - _add_empty_priority_class_arrays( - rules['device'][profile_tag] - ) - ) - - rulearray = rules['device'][profile_tag][template_name] - else: - rulearray = rules['global'][template_name] + rulearray = rules['global'][template_name] template_rule = _rule_to_template(r) if template_rule: @@ -195,24 +182,6 @@ class PushRuleRestServlet(ClientV1RestServlet): path = path[1:] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) - elif path[0] == 'device': - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if path[0] == '': - defer.returnValue((200, rules['device'])) - - profile_tag = path[0] - path = path[1:] - if profile_tag not in rules['device']: - ret = {} - ret = _add_empty_priority_class_arrays(ret) - defer.returnValue((200, ret)) - ruleset = rules['device'][profile_tag] - result = _filter_ruleset_with_path(ruleset, path) - defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -232,13 +201,17 @@ class PushRuleRestServlet(ClientV1RestServlet): return self.hs.get_datastore().set_push_rule_enabled( user_id, namespaced_rule_id, val ) - else: - raise UnrecognizedRequestError() - - def get_rule_attr(self, user_id, namespaced_rule_id, attr): - if attr == 'enabled': - return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( - user_id, namespaced_rule_id + elif spec['attr'] == 'actions': + actions = val.get('actions') + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec['rule_id'] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return self.hs.get_datastore().set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule ) else: raise UnrecognizedRequestError() @@ -252,16 +225,9 @@ def _rule_spec_from_path(path): scope = path[1] path = path[2:] - if scope not in ['global', 'device']: + if scope != 'global': raise UnrecognizedRequestError() - device = None - if scope == 'device': - if len(path) == 0: - raise UnrecognizedRequestError() - device = path[0] - path = path[1:] - if len(path) == 0: raise UnrecognizedRequestError() @@ -278,8 +244,6 @@ def _rule_spec_from_path(path): 'template': template, 'rule_id': rule_id } - if device: - spec['profile_tag'] = device path = path[1:] @@ -289,7 +253,7 @@ def _rule_spec_from_path(path): return spec -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): if rule_template in ['override', 'underride']: if 'conditions' not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -322,16 +286,19 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if device: - conditions.append({ - 'kind': 'device', - 'profile_tag': device - }) - if 'actions' not in req_obj: raise InvalidRuleException("No actions found") actions = req_obj['actions'] + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass @@ -340,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unrecognised action") - return conditions, actions - def _add_empty_priority_class_arrays(d): for pc in PRIORITY_CLASS_MAP.keys(): @@ -349,17 +314,6 @@ def _add_empty_priority_class_arrays(d): return d -def _profile_tag_from_conditions(conditions): - """ - Given a list of conditions, return the profile tag of the - device rule if there is one - """ - for c in conditions: - if c['kind'] == 'device': - return c['profile_tag'] - return None - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -393,29 +347,23 @@ def _filter_ruleset_with_path(ruleset, path): attr = path[0] if attr in the_rule: - return the_rule[attr] + # Make sure we return a JSON object as the attribute may be a + # JSON value. + return {attr: the_rule[attr]} else: raise UnrecognizedRequestError() def _priority_class_from_spec(spec): if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] - if spec['scope'] == 'device': - pc += len(PRIORITY_CLASS_MAP) - return pc def _priority_class_to_template_name(pc): - if pc > PRIORITY_CLASS_MAP['override']: - # per-device - prio_class_index = pc - len(PRIORITY_CLASS_MAP) - return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] - else: - return PRIORITY_CLASS_INVERSE_MAP[pc] + return PRIORITY_CLASS_INVERSE_MAP[pc] def _rule_to_template(rule): @@ -445,23 +393,12 @@ def _rule_to_template(rule): return templaterule -def _strip_device_condition(rule): - for i, c in enumerate(rule['conditions']): - if c['kind'] == 'device': - del rule['conditions'][i] - return rule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) def _namespaced_rule_id(spec, rule_id): - if spec['scope'] == 'global': - scope = 'global' - else: - scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], rule_id) + return "global/%s/%s" % (spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5547f1b112..4c662e6e3c 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -45,7 +45,7 @@ class PusherRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', + reqd = ['kind', 'app_id', 'app_display_name', 'device_display_name', 'pushkey', 'lang', 'data'] missing = [] for i in reqd: @@ -73,14 +73,14 @@ class PusherRestServlet(ClientV1RestServlet): yield pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], app_display_name=content['app_display_name'], device_display_name=content['device_display_name'], pushkey=content['pushkey'], lang=content['lang'], - data=content['data'] + data=content['data'], + profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: raise SynapseError(400, "Config Error: " + pce.message, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 81bfe377bd..f5ed4f7302 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -63,24 +63,12 @@ class RoomCreateRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - room_config = self.get_room_config(request) - info = yield self.make_room( - room_config, - requester.user, - None, - ) - room_config.update(info) - defer.returnValue((200, info)) - - @defer.inlineCallbacks - def make_room(self, room_config, auth_user, room_id): handler = self.handlers.room_creation_handler info = yield handler.create_room( - user_id=auth_user.to_string(), - room_id=room_id, - config=room_config + requester, self.get_room_config(request) ) - defer.returnValue(info) + + defer.returnValue((200, info)) def get_room_config(self, request): try: @@ -162,11 +150,22 @@ class RoomStateEventRestServlet(ClientV1RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event( - event_dict, token_id=requester.access_token_id, txn_id=txn_id, + event, context = yield msg_handler.create_event( + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id, ) - defer.returnValue((200, {})) + if event_type == EventTypes.Member: + yield self.handlers.room_member_handler.send_membership_event( + event, + context, + is_guest=requester.is_guest, + ) + else: + yield msg_handler.send_nonmember_event(event, context) + + defer.returnValue((200, {"event_id": event.event_id})) # TODO: Needs unit testing for generic events + feedback @@ -183,7 +182,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): content = _parse_json(request) msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_nonmember_event( { "type": event_type, "content": content, @@ -229,46 +228,37 @@ class JoinRoomAliasServlet(ClientV1RestServlet): allow_guest=True, ) - # the identifier could be a room alias or a room id. Try one then the - # other if it fails to parse, without swallowing other valid - # SynapseErrors. - - identifier = None - is_room_alias = False try: - identifier = RoomAlias.from_string(room_identifier) - is_room_alias = True - except SynapseError: - identifier = RoomID.from_string(room_identifier) - - # TODO: Support for specifying the home server to join with? - - if is_room_alias: + content = _parse_json(request) + except: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): handler = self.handlers.room_member_handler - ret_dict = yield handler.join_room_alias( - requester.user, - identifier, - ) - defer.returnValue((200, ret_dict)) - else: # room id - msg_handler = self.handlers.message_handler - content = {"membership": Membership.JOIN} - if requester.is_guest: - content["kind"] = "guest" - yield msg_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": identifier.to_string(), - "sender": requester.user.to_string(), - "state_key": requester.user.to_string(), - }, - token_id=requester.access_token_id, - txn_id=txn_id, - is_guest=requester.is_guest, - ) + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) - defer.returnValue((200, {"room_id": identifier.to_string()})) + yield self.handlers.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=content.get("third_party_signed", None), + ) + + defer.returnValue((200, {"room_id": room_id})) @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): @@ -316,18 +306,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet): if event["type"] != EventTypes.Member: continue chunk.append(event) - # FIXME: should probably be state_key here, not user_id - target_user = UserID.from_string(event["user_id"]) - # Presence is an optional cache; don't fail if we can't fetch it - try: - presence_handler = self.handlers.presence_handler - presence_state = yield presence_handler.get_state( - target_user=target_user, - auth_user=requester.user, - ) - event["content"].update(presence_state) - except: - pass defer.returnValue((200, { "chunk": chunk @@ -454,7 +432,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): }: raise AuthError(403, "Guest access not allowed") - content = _parse_json(request) + try: + content = _parse_json(request) + except: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): yield self.handlers.room_member_handler.do_3pid_invite( @@ -463,7 +446,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content["medium"], content["address"], content["id_server"], - requester.access_token_id, + requester, txn_id ) defer.returnValue((200, {})) @@ -481,6 +464,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): room_id=room_id, action=membership_action, txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), ) defer.returnValue((200, {})) @@ -519,7 +503,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): content = _parse_json(request) msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_event( + event = yield msg_handler.create_and_send_nonmember_event( { "type": EventTypes.Redaction, "content": content, @@ -553,6 +537,10 @@ class RoomTypingRestServlet(ClientV1RestServlet): "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" ) + def __init__(self, hs): + super(RoomTypingRestServlet, self).__init__(hs) + self.presence_handler = hs.get_handlers().presence_handler + @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) @@ -564,6 +552,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): typing_handler = self.handlers.typing_notification_handler + yield self.presence_handler.bump_presence_active_time(requester.user) + if content["typing"]: yield typing_handler.started_typing( target_user=target_user, diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index ff71c40b43..78181b7b18 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.http.server import finish_request from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -130,7 +131,7 @@ class AuthRestServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) - request.finish() + finish_request(request) defer.returnValue(None) else: raise SynapseError(404, "Unknown auth stage type") @@ -176,7 +177,7 @@ class AuthRestServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) - request.finish() + finish_request(request) defer.returnValue(None) else: diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index eb4b369a3d..b831d8c95e 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_handlers().receipts_handler + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): @@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.presence_handler.bump_presence_active_time(requester.user) + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index accbc6cfac..de4a020ad4 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.events.utils import ( ) from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError +from synapse.api.constants import PresenceState from ._base import client_v2_patterns import copy @@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet): self.sync_handler = hs.get_handlers().sync_handler self.clock = hs.get_clock() self.filtering = hs.get_filtering() + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_GET(self, request): @@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet): else: since_token = None - if set_presence == "online": - yield self.event_stream_handler.started_stream(user) + affect_presence = set_presence != PresenceState.OFFLINE - try: + if affect_presence: + yield self.presence_handler.set_state(user, {"presence": set_presence}) + + context = yield self.presence_handler.user_syncing( + user.to_string(), affect_presence=affect_presence, + ) + with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, full_state=full_state ) - finally: - if set_presence == "online": - self.event_stream_handler.stopped_stream(user) time_now = self.clock.time_msec() diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index dcf3eaee1f..d9fc045fc6 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json_bytes +from synapse.http.server import respond_with_json_bytes, finish_request from synapse.util.stringutils import random_string from synapse.api.errors import ( @@ -144,7 +144,7 @@ class ContentRepoResource(resource.Resource): # after the file has been sent, clean up and finish the request def cbFinished(ignored): f.close() - request.finish() + finish_request(request) d.addCallback(cbFinished) else: respond_with_json_bytes( diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 58d56ec7a4..58ef91c0b8 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -16,7 +16,7 @@ from .thumbnailer import Thumbnailer from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.http.server import respond_with_json +from synapse.http.server import respond_with_json, finish_request from synapse.util.stringutils import random_string from synapse.api.errors import ( cs_error, Codes, SynapseError @@ -238,7 +238,7 @@ class BaseMediaResource(Resource): with open(file_path, "rb") as f: yield FileSender().beginFileTransfer(f, request) - request.finish() + finish_request(request) else: self._respond_404(request) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 5a9e7720d9..f257721ea3 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,7 @@ from .appservice import ( from ._base import Cache from .directory import DirectoryStore from .events import EventsStore -from .presence import PresenceStore +from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore from .registration import RegistrationStore from .room import RoomStore @@ -47,6 +47,7 @@ from .account_data import AccountDataStore from util.id_generators import IdGenerator, StreamIdGenerator +from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -110,16 +111,19 @@ class DataStore(RoomMemberStore, RoomStore, self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) - self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) - self._state_groups_id_gen = IdGenerator("state_groups", "id", self) - self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) - self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) - self._pushers_id_gen = IdGenerator("pushers", "id", self) - self._push_rule_id_gen = IdGenerator("push_rules", "id", self) - self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - events_max = self._stream_id_gen.get_max_token(None) + events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token(None) + account_max = self._account_data_id_gen.get_max_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) + self.__presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_max_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", min_presence_val, + prefilled_cache=presence_cache_prefill + ) + super(DataStore, self).__init__(hs) + def take_presence_startup_info(self): + active_on_startup = self.__presence_on_startup + self.__presence_on_startup = None + return active_on_startup + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will @@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() + txn.close() cache = { row[0]: int(row[1]) @@ -174,6 +197,28 @@ class DataStore(RoomMemberStore, RoomStore, return cache, min_val + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index b8387fc500..faddefe219 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): - """Get all the client account_data for a that's changed. + def get_all_updated_account_data(self, last_global_id, last_room_id, + current_id, limit): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return (global_results, room_results) + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. @@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_room_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_user_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index ce2c794025..3489315e0d 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) - def get_latest_events_in_room(self, room_id): + def get_latest_event_ids_and_hashes_in_room(self, room_id): return self.runInteraction( - "get_latest_events_in_room", - self._get_latest_events_in_room, + "get_latest_event_ids_and_hashes_in_room", + self._get_latest_event_ids_and_hashes_in_room, room_id, ) @@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore): desc="get_latest_event_ids_in_room", ) - def _get_latest_events_in_room(self, txn, room_id): + def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): sql = ( "SELECT e.event_id, e.depth FROM events as e " "INNER JOIN event_forward_extremities as f " diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d77a817682..5820539a92 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for - :param tuples: list of tuples of (user_id, profile_tag, actions) + :param tuples: list of tuples of (user_id, actions) """ values = [] - for uid, profile_tag, actions in tuples: + for uid, actions in tuples: values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'profile_tag': profile_tag, 'actions': json.dumps(actions), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, @@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - for uid, _, __ in tuples: + for uid, __ in tuples: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 3a5c6ee4b1..60936500d8 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore): yield stream_orderings stream_ordering_manager = stream_ordering_manager() else: - stream_ordering_manager = yield self._stream_id_gen.get_next_mult( - self, len(events_and_contexts) + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) ) with stream_ordering_manager as stream_orderings: @@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore): stream_ordering = self.min_stream_token if stream_ordering is None: - stream_ordering_manager = yield self._stream_id_gen.get_next(self) + stream_ordering_manager = self._stream_id_gen.get_next() else: @contextmanager def stream_ordering_manager(): @@ -131,7 +131,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = yield self._stream_id_gen.get_max_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore): yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) defer.returnValue(result) + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + + # TODO: Fix race with the persit_event txn by using one of the + # stream id managers + return -self.min_stream_token + + def get_all_new_events(self, last_backfill_id, last_forward_id, + current_backfill_id, current_forward_id, limit): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + if last_forward_id != current_forward_id: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + else: + new_forward_events = [] + + sql = ( + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" + " ORDER BY e.stream_ordering DESC" + " LIMIT ?" + ) + if last_backfill_id != current_backfill_id: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + else: + new_backfill_events = [] + + return (new_forward_events, new_backfill_events) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..0fd5d497ab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index ef525f34c5..4cec31e316 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,73 +14,148 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.api.constants import PresenceState +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from collections import namedtuple from twisted.internet import defer +class UserPresenceState(namedtuple("UserPresenceState", + ("user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active"))): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. + """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ) + + class PresenceStore(SQLBaseStore): - def create_presence(self, user_localpart): - res = self._simple_insert( - table="presence", - values={"user_id": user_localpart}, - desc="create_presence", + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) ) - self.get_presence_state.invalidate((user_localpart,)) - return res + with stream_ordering_manager as stream_orderings: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, stream_orderings, presence_states, + ) - def has_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["user_id"], - allow_none=True, - desc="has_presence_state", + defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, + state.user_id, stream_id, + ) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], ) - @cached(max_entries=2000) - def get_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + # Delete old rows to stop database from getting really big + sql = ( + "DELETE FROM presence_stream WHERE" + " stream_id < ?" + " AND user_id IN (%s)" ) - @cachedList(get_presence_state.cache, list_name="user_localparts", - inlineCallbacks=True) - def get_presence_states(self, user_localparts): - rows = yield self._simple_select_many_batch( - table="presence", - column="user_id", - iterable=user_localparts, - retcols=("user_id", "state", "status_msg", "mtime",), - desc="get_presence_states", + batches = ( + presence_states[i:i + 50] + for i in xrange(0, len(presence_states), 50) ) + for states in batches: + args = [stream_id] + args.extend(s.user_id for s in states) + txn.execute( + sql % (",".join("?" for _ in states),), + args + ) + + def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() - defer.returnValue({ - row["user_id"]: { - "state": row["state"], - "status_msg": row["status_msg"], - "mtime": row["mtime"], - } - for row in rows - }) + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) @defer.inlineCallbacks - def set_presence_state(self, user_localpart, new_state): - res = yield self._simple_update_one( - table="presence", - keyvalues={"user_id": user_localpart}, - updatevalues={"state": new_state["state"], - "status_msg": new_state["status_msg"], - "mtime": self._clock.time_msec()}, - desc="set_presence_state", + def get_presence_for_users(self, user_ids): + rows = yield self._simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), ) - self.get_presence_state.invalidate((user_localpart,)) - defer.returnValue(res) + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + defer.returnValue([UserPresenceState(**row) for row in rows]) + + def get_current_presence_token(self): + return self._presence_id_gen.get_max_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -128,6 +203,7 @@ class PresenceStore(SQLBaseStore): desc="set_presence_list_accepted", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -154,6 +230,19 @@ class PresenceStore(SQLBaseStore): desc="get_presence_list_accepted", ) + @cachedInlineCallbacks() + def get_presence_list_observers_accepted(self, observed_userid): + user_localparts = yield self._simple_select_onecol( + table="presence_list", + keyvalues={"observed_user_id": observed_userid, "accepted": True}, + retcol="user_id", + desc="get_presence_list_accepted", + ) + + defer.returnValue([ + "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts + ]) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): yield self._simple_delete_one( @@ -163,3 +252,4 @@ class PresenceStore(SQLBaseStore): desc="del_presence_list", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index f9a48171ba..56e69495b1 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -99,38 +99,36 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] defer.returnValue(results) - @defer.inlineCallbacks - def add_push_rule(self, before, after, **kwargs): - vals = kwargs - if 'conditions' in vals: - vals['conditions'] = json.dumps(vals['conditions']) - if 'actions' in vals: - vals['actions'] = json.dumps(vals['actions']) - - # we could check the rest of the keys are valid column names - # but sqlite will do that anyway so I think it's just pointless. - vals.pop("id", None) + def add_push_rule( + self, user_id, rule_id, priority_class, conditions, actions, + before=None, after=None + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) if before or after: - ret = yield self.runInteraction( + return self.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, - before=before, - after=after, - **vals + user_id, rule_id, priority_class, + conditions_json, actions_json, before, after, ) - defer.returnValue(ret) else: - ret = yield self.runInteraction( + return self.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, - **vals + user_id, rule_id, priority_class, + conditions_json, actions_json, ) - defer.returnValue(ret) - def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): - after = kwargs.pop("after", None) - before = kwargs.pop("before", None) + def _add_push_rule_relative_txn( + self, txn, user_id, rule_id, priority_class, + conditions_json, actions_json, before, after + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + relative_to_rule = before or after res = self._simple_select_one_txn( @@ -149,69 +147,45 @@ class PushRuleStore(SQLBaseStore): "before/after rule not found: %s" % (relative_to_rule,) ) - priority_class = res["priority_class"] + base_priority_class = res["priority_class"] base_rule_priority = res["priority"] - if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + if base_priority_class != priority_class: raise InconsistentRuleException( "Given priority class does not match class of relative rule" ) - new_rule = kwargs - new_rule.pop("before", None) - new_rule.pop("after", None) - new_rule['priority_class'] = priority_class - new_rule['user_name'] = user_id - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - - # check if the priority before/after is free - new_rule_priority = base_rule_priority - if after: - new_rule_priority -= 1 + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 else: - new_rule_priority += 1 - - new_rule['priority'] = new_rule_priority + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority sql = ( - "SELECT COUNT(*) FROM push_rules" - " WHERE user_name = ? AND priority_class = ? AND priority = ?" + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" ) - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - res = txn.fetchall() - num_conflicting = res[0][0] - - # if there are conflicting rules, bump everything - if num_conflicting: - sql = "UPDATE push_rules SET priority = priority " - if after: - sql += "-1" - else: - sql += "+1" - sql += " WHERE user_name = ? AND priority_class = ? AND priority " - if after: - sql += "<= ?" - else: - sql += ">= ?" - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) + txn.execute(sql, (user_id, priority_class, new_rule_priority)) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, new_rule_priority, + conditions_json, actions_json, ) - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) + def _add_push_rule_highest_priority_txn( + self, txn, user_id, rule_id, priority_class, + conditions_json, actions_json + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") - def _add_push_rule_highest_priority_txn(self, txn, user_id, - priority_class, **kwargs): # find the highest priority rule in that class sql = ( "SELECT COUNT(*), MAX(priority) FROM push_rules" @@ -225,12 +199,48 @@ class PushRuleStore(SQLBaseStore): if how_many > 0: new_prio = highest_prio + 1 - # and insert the new rule - new_rule = kwargs - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - new_rule['user_name'] = user_id - new_rule['priority_class'] = priority_class - new_rule['priority'] = new_prio + self._upsert_push_rule_txn( + txn, + user_id, rule_id, priority_class, new_prio, + conditions_json, actions_json, + ) + + def _upsert_push_rule_txn( + self, txn, user_id, rule_id, priority_class, + priority, conditions_json, actions_json + ): + """Specialised version of _simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" + ) + + txn.execute(sql, ( + priority_class, priority, conditions_json, actions_json, + user_id, rule_id, + )) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self._simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) txn.call_after( self.get_push_rules_for_user.invalidate, (user_id,) @@ -239,12 +249,6 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -275,7 +279,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(ret) def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): - new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( txn, "push_rules_enable", @@ -290,6 +294,31 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, priority, + "[]", actions_json + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + {'actions': actions_json}, + ) + + return self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + ) + class RuleNotFoundException(Exception): pass diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8ec706178a..7693ab9082 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -80,11 +80,11 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, + def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data): + pushkey, pushkey_ts, lang, data, profile_tag=""): try: - next_id = yield self._pushers_id_gen.get_next() + next_id = self._pushers_id_gen.get_next() yield self._simple_upsert( "pushers", dict( @@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore): dict( access_token=access_token, kind=kind, - profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, data=encode_canonical_json(data), + profile_tag=profile_tag, ), insertion_values=dict( id=next_id, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4202a6b3dc..dbc074d6b5 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() ) @cached(num_args=2) @@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token(self) + return self._receipts_id_gen.get_max_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = yield self._receipts_id_gen.get_next(self) + stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: have_persisted = yield self.runInteraction( "insert_linearized_receipt", @@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = self._stream_id_gen.get_max_token() defer.returnValue((stream_id, max_persisted_id)) @@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) + + def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 967c732bda..ad1157f979 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._access_tokens_id_gen.get_next() + next_id = self._access_tokens_id_gen.get_next() yield self._simple_insert( "access_tokens", @@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._refresh_tokens_id_gen.get_next() + next_id = self._refresh_tokens_id_gen.get_next() yield self._simple_insert( "refresh_tokens", @@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore): def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next_txn(txn) + next_id = self._access_tokens_id_gen.get_next() try: if was_guest: @@ -387,3 +387,47 @@ class RegistrationStore(SQLBaseStore): "find_next_generated_user_id", _find_next_generated_user_id ))) + + @defer.inlineCallbacks + def get_3pid_guest_access_token(self, medium, address): + ret = yield self._simple_select_one( + "threepid_guest_access_tokens", + { + "medium": medium, + "address": address + }, + ["guest_access_token"], True, 'get_3pid_guest_access_token' + ) + if ret: + defer.returnValue(ret["guest_access_token"]) + defer.returnValue(None) + + @defer.inlineCallbacks + def save_or_get_3pid_guest_access_token( + self, medium, address, access_token, inviter_user_id + ): + """ + Gets the 3pid's guest access token if exists, else saves access_token. + + :param medium (str): Medium of the 3pid. Must be "email". + :param address (str): 3pid address. + :param access_token (str): The access token to persist if none is + already persisted. + :param inviter_user_id (str): User ID of the inviter. + :return (deferred str): Whichever access token is persisted at the end + of this function call. + """ + def insert(txn): + txn.execute( + "INSERT INTO threepid_guest_access_tokens " + "(medium, address, guest_access_token, first_inviter) " + "VALUES (?, ?, ?, ?)", + (medium, address, access_token, inviter_user_id) + ) + + try: + yield self.runInteraction("save_3pid_guest_access_token", insert) + defer.returnValue(access_token) + except self.database_engine.module.IntegrityError: + ret = yield self.get_3pid_guest_access_token(medium, address) + defer.returnValue(ret) diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..606bbb037d --- /dev/null +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active_ts BIGINT, + last_federation_update_ts BIGINT, + last_user_sync_ts BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql new file mode 100644 index 0000000000..0dd2f1360c --- /dev/null +++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores guest account access tokens generated for unbound 3pids. +CREATE TABLE threepid_guest_access_tokens( + medium TEXT, -- The medium of the 3pid. Must be "email". + address TEXT, -- The 3pid address. + guest_access_token TEXT, -- The access token for a guest user for this 3pid. + first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. +); + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 372b540002..8ed8a21b0a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,7 +83,7 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next_txn(txn) + state_group = self._state_groups_id_gen.get_next() self._simple_insert_txn( txn, table="state_groups", diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c236dafafb..8908d5b5da 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -531,7 +531,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token(self) + token = yield self._stream_id_gen.get_max_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index e1a9c0c261..a0e6b42b30 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token(self) + return self._account_data_id_gen.get_max_token() @cached() def get_tags_for_user(self, user_id): @@ -59,6 +59,59 @@ class TagsStore(SQLBaseStore): return deferred @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" + " FROM room_tags" + " WHERE user_id=? AND room_id=?" + ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn.fetchall(): + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in xrange(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i:i + batch_size], + ) + results.extend(tags) + + defer.returnValue(results) + + @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): """Get all the tags for the rooms where the tags have changed since the given version @@ -142,12 +195,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -164,12 +217,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 4475c451c1..d338dfcf0a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): - next_id = self._transaction_id_gen.get_next_txn(txn) + next_id = self._transaction_id_gen.get_next() # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4ab9..efe3f68e6e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,51 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from collections import deque import contextlib import threading class IdGenerator(object): - def __init__(self, table, column, store): + def __init__(self, db_conn, table, column): self.table = table self.column = column - self.store = store self._lock = threading.Lock() - self._next_id = None + cur = db_conn.cursor() + self._next_id = self._load_next_id(cur) + cur.close() - @defer.inlineCallbacks - def get_next(self): - if self._next_id is None: - yield self.store.runInteraction( - "IdGenerator_%s" % (self.table,), - self.get_next_txn, - ) + def _load_next_id(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) + val, = txn.fetchone() + return val + 1 if val else 1 + def get_next(self): with self._lock: i = self._next_id self._next_id += 1 - defer.returnValue(i) - - def get_next_txn(self, txn): - with self._lock: - if self._next_id: - i = self._next_id - self._next_id += 1 - return i - else: - txn.execute( - "SELECT MAX(%s) FROM %s" % (self.column, self.table,) - ) - - val, = txn.fetchone() - cur = val or 0 - cur += 1 - self._next_id = cur + 1 - - return cur + return i class StreamIdGenerator(object): @@ -69,7 +48,7 @@ class StreamIdGenerator(object): persistence of events can complete out of order. Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ def __init__(self, db_conn, table, column): @@ -79,15 +58,21 @@ class StreamIdGenerator(object): self._lock = threading.Lock() cur = db_conn.cursor() - self._current_max = self._get_or_compute_current_max(cur) + self._current_max = self._load_current_max(cur) cur.close() self._unfinished_ids = deque() - def get_next(self, store): + def _load_current_max(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) + rows = txn.fetchall() + val, = rows[0] + return int(val) if val else 1 + + def get_next(self): """ Usage: - with yield stream_id_gen.get_next as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -106,10 +91,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, store, n): + def get_next_mult(self, n): """ Usage: - with yield stream_id_gen.get_next(store, n) as stream_ids: + with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -130,7 +115,7 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -139,13 +124,3 @@ class StreamIdGenerator(object): return self._unfinished_ids[0] - 1 return self._current_max - - def _get_or_compute_current_max(self, txn): - with self._lock: - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] - - self._current_max = int(val) if val else 1 - - return self._current_max diff --git a/synapse/types.py b/synapse/types.py index 2095837ba6..d5bd95cbd3 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -73,6 +73,14 @@ class DomainSpecificString( """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + @classmethod + def is_valid(cls, s): + try: + cls.from_string(s) + return True + except: + return False + __str__ = to_string @classmethod diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 133671e238..3b9da5b34a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -42,7 +42,7 @@ class Clock(object): def time_msec(self): """Returns the current system time in miliseconds since epoch.""" - return self.time() * 1000 + return int(self.time() * 1000) def looping_call(self, f, msec): l = task.LoopingCall(f) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 277854ccbc..35544b19fd 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -28,6 +28,7 @@ from twisted.internet import defer from collections import OrderedDict +import os import functools import inspect import threading @@ -38,6 +39,9 @@ logger = logging.getLogger(__name__) _CacheSentinel = object() +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): @@ -140,6 +144,8 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, inlineCallbacks=False): + max_entries = int(max_entries * CACHE_SIZE_FACTOR) + self.orig = orig if inlineCallbacks: diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 62cae99649..e863a8f8a9 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.caches import cache_counter, caches_by_name + import logging @@ -47,6 +49,8 @@ class ExpiringCache(object): self._cache = {} + caches_by_name[cache_name] = self._cache + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire @@ -72,7 +76,12 @@ class ExpiringCache(object): self._cache.pop(k) def __getitem__(self, key): - entry = self._cache[key] + try: + entry = self._cache[key] + cache_counter.inc_hits(self._cache_name) + except KeyError: + cache_counter.inc_misses(self._cache_name) + raise if self._reset_expiry_on_get: entry.time = self._clock.time_msec() @@ -105,9 +114,12 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache.keys()) + self._cache_name, begin_length, len(self._cache) ) + def __len__(self): + return len(self._cache) + class _CacheEntry(object): def __init__(self, time, value): diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index b37f1c0725..a1aec7aa55 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -18,11 +18,15 @@ from synapse.util.caches import cache_counter, caches_by_name from blist import sorteddict import logging +import os logger = logging.getLogger(__name__) +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + class StreamChangeCache(object): """Keeps track of the stream positions of the latest change in a set of entities. @@ -33,7 +37,7 @@ class StreamChangeCache(object): old then the cache will simply return all given entities. """ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}): - self._max_size = max_size + self._max_size = int(max_size * CACHE_SIZE_FACTOR) self._entity_to_key = {} self._cache = sorteddict() self._earliest_known_stream_pos = current_stream_pos @@ -85,6 +89,22 @@ class StreamChangeCache(object): return result + def get_all_entities_changed(self, stream_pos): + """Returns all entites that have had new things since the given + position. If the position is too old it will return None. + """ + assert type(stream_pos) is int + + if stream_pos >= self._earliest_known_stream_pos: + keys = self._cache.keys() + i = keys.bisect_right(stream_pos) + + return ( + self._cache[k] for k in keys[i:] + ) + else: + return None + def entity_has_changed(self, entity, stream_pos): """Informs the cache that the entity has been changed at the given position. diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py new file mode 100644 index 0000000000..7412fc57a4 --- /dev/null +++ b/synapse/util/wheel_timer.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class _Entry(object): + __slots__ = ["end_key", "queue"] + + def __init__(self, end_key): + self.end_key = end_key + self.queue = [] + + +class WheelTimer(object): + """Stores arbitrary objects that will be returned after their timers have + expired. + """ + + def __init__(self, bucket_size=5000): + """ + Args: + bucket_size (int): Size of buckets in ms. Corresponds roughly to the + accuracy of the timer. + """ + self.bucket_size = bucket_size + self.entries = [] + self.current_tick = 0 + + def insert(self, now, obj, then): + """Inserts object into timer. + + Args: + now (int): Current time in msec + obj (object): Object to be inserted + then (int): When to return the object strictly after. + """ + then_key = int(then / self.bucket_size) + 1 + + if self.entries: + min_key = self.entries[0].end_key + max_key = self.entries[-1].end_key + + if then_key <= max_key: + # The max here is to protect against inserts for times in the past + self.entries[max(min_key, then_key) - min_key].queue.append(obj) + return + + next_key = int(now / self.bucket_size) + 1 + if self.entries: + last_key = self.entries[-1].end_key + else: + last_key = next_key + + # Handle the case when `then` is in the past and `entries` is empty. + then_key = max(last_key, then_key) + + # Add empty entries between the end of the current list and when we want + # to insert. This ensures there are no gaps. + self.entries.extend( + _Entry(key) for key in xrange(last_key, then_key + 1) + ) + + self.entries[-1].queue.append(obj) + + def fetch(self, now): + """Fetch any objects that have timed out + + Args: + now (ms): Current time in msec + + Returns: + list: List of objects that have timed out + """ + now_key = int(now / self.bucket_size) + + ret = [] + while self.entries and self.entries[0].end_key <= now_key: + ret.extend(self.entries.pop(0).queue) + + return ret + + def __len__(self): + l = 0 + for entry in self.entries: + l += len(entry.queue) + return l diff --git a/tests/__init__.py b/tests/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 474c5c418f..7e7b0b4b1d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,9 +281,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < 1") # ms - self.hs.clock.now = 5000 # seconds + self.hs.clock.now = 5000 # seconds yield self.auth.get_user_from_macaroon(macaroon.serialize()) # TODO(daniel): Turn on the check that we validate expiration, when we diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index ceb0089268..dcb6c5bc31 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -21,7 +21,6 @@ from tests.utils import ( MockHttpResource, DeferredMockCallable, setup_test_homeserver ) -from synapse.types import UserID from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -356,7 +355,6 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.*"] } } - user = UserID.from_string("@" + user_localpart + ":test") filter_id = yield self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json, @@ -411,7 +409,6 @@ class FilteringTestCase(unittest.TestCase): } } } - user = UserID.from_string("@" + user_localpart + ":test") filter_id = yield self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json, @@ -440,7 +437,6 @@ class FilteringTestCase(unittest.TestCase): } } } - user = UserID.from_string("@" + user_localpart + ":test") filter_id = yield self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json, @@ -460,6 +456,22 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_room_state(events) self.assertEquals([], results) + def test_filter_rooms(self): + definition = { + "rooms": ["!allowed:example.com", "!excluded:example.com"], + "not_rooms": ["!excluded:example.com"], + } + + room_ids = [ + "!allowed:example.com", # Allowed because in rooms and not in not_rooms. + "!excluded:example.com", # Disallowed because in not_rooms. + "!not_included:example.com", # Disallowed because not in rooms. + ] + + filtered_room_ids = list(Filter(definition).filter_rooms(room_ids)) + + self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + @defer.inlineCallbacks def test_add_filter(self): user_filter_json = { @@ -476,12 +488,12 @@ class FilteringTestCase(unittest.TestCase): ) self.assertEquals(filter_id, 0) - self.assertEquals(user_filter_json, - (yield self.datastore.get_user_filter( + self.assertEquals(user_filter_json, ( + yield self.datastore.get_user_filter( user_localpart=user_localpart, filter_id=0, - )) - ) + ) + )) @defer.inlineCallbacks def test_get_filter(self): @@ -504,3 +516,5 @@ class FilteringTestCase(unittest.TestCase): ) self.assertEquals(filter.get_filter_json(), user_filter_json) + + self.assertRegexpMatches(repr(filter), r"<FilterCollection \{.*\}>") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dd0bc19ecf..c45b59b36c 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -2,6 +2,7 @@ from synapse.api.ratelimiting import Ratelimiter from tests import unittest + class TestRatelimiter(unittest.TestCase): def test_allowed(self): diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index ef48bbc296..d6cc1881e9 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.appservice import ApplicationService -from mock import Mock, PropertyMock +from mock import Mock from tests import unittest diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index c9c2d36210..631a229332 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice import ApplicationServiceState, AppServiceTransaction +from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( _ServiceQueuer, _TransactionController, _Recoverer ) @@ -235,7 +235,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - self.txn_ctrl.send = Mock(side_effect=lambda x,y: send_return_list.pop(0)) + self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0)) # send events for different ASes and make sure they are sent self.queuer.enqueue(srv1, srv_1_event) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index fbbbf93fef..bf46233c5c 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -60,6 +60,22 @@ class ConfigLoadingTestCase(unittest.TestCase): config2 = HomeServerConfig.load_config("", ["-c", self.file]) self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) + def test_disable_registration(self): + self.generate_config() + self.add_lines_to_config([ + "enable_registration: true", + "disable_registration: true", + ]) + # Check that disable_registration clobbers enable_registration. + config = HomeServerConfig.load_config("", ["-c", self.file]) + self.assertFalse(config.enable_registration) + + # Check that either config value is clobbered by the command line. + config = HomeServerConfig.load_config("", [ + "-c", self.file, "--enable-registration" + ]) + self.assertTrue(config.enable_registration) + def generate_config(self): HomeServerConfig.load_config("", [ "--generate-config", @@ -76,3 +92,8 @@ class ConfigLoadingTestCase(unittest.TestCase): contents = [l for l in contents if needle not in l] with open(self.file, "w") as f: f.write("".join(contents)) + + def add_lines_to_config(self, lines): + with open(self.file, "a") as f: + for line in lines: + f.write(line + "\n") diff --git a/tests/crypto/__init__.py b/tests/crypto/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/crypto/__init__.py +++ b/tests/crypto/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 894d0c3845..fb0953c4ec 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -19,6 +19,7 @@ from .. import unittest from synapse.events import FrozenEvent from synapse.events.utils import prune_event + class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. """ diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ba6e2c640e..7ddbbb9b4a 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -129,8 +129,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - - def _mkservice(self, is_interested): service = Mock() service.is_interested = Mock(return_value=is_interested) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 2f21bf91e5..21077cbe9a 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -15,7 +15,6 @@ import pymacaroons -from mock import Mock, NonCallableMock from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 447a22b5fc..87c795fcfa 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,1326 +15,370 @@ from tests import unittest -from twisted.internet import defer, reactor -from mock import Mock, call, ANY, NonCallableMock -import json - -from tests.utils import ( - MockHttpResource, MockClock, DeferredMockCallable, setup_test_homeserver -) +from mock import Mock, call from synapse.api.constants import PresenceState -from synapse.api.errors import SynapseError -from synapse.handlers.presence import PresenceHandler, UserPresenceCache -from synapse.streams.config import SourcePaginationConfig -from synapse.types import UserID - -OFFLINE = PresenceState.OFFLINE -UNAVAILABLE = PresenceState.UNAVAILABLE -ONLINE = PresenceState.ONLINE - - -def _expect_edu(destination, edu_type, content, origin="test"): - return { - "origin": origin, - "origin_server_ts": 1000000, - "pdus": [], - "edus": [ - { - "edu_type": edu_type, - "content": content, - } - ], - "pdu_failures": [], - } - -def _make_edu_json(origin, edu_type, content): - return json.dumps(_expect_edu("test", edu_type, content, origin=origin)) - - -class JustPresenceHandlers(object): - def __init__(self, hs): - self.presence_handler = PresenceHandler(hs) - - -class PresenceTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.clock = MockClock() - - self.mock_federation_resource = MockHttpResource() - - self.mock_http_client = Mock(spec=[]) - self.mock_http_client.put_json = DeferredMockCallable() - - hs_kwargs = {} - if hasattr(self, "make_datastore_mock"): - hs_kwargs["datastore"] = self.make_datastore_mock() - - hs = yield setup_test_homeserver( - clock=self.clock, - handlers=None, - resource_for_federation=self.mock_federation_resource, - http_client=self.mock_http_client, - keyring=Mock(), - **hs_kwargs - ) - hs.handlers = JustPresenceHandlers(hs) - - self.datastore = hs.get_datastore() - - self.setUp_roommemberhandler_mocks(hs.handlers) - - self.handler = hs.get_handlers().presence_handler - self.event_source = hs.get_event_sources().sources["presence"] - - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - - yield self.setUp_users(hs) - - def setUp_roommemberhandler_mocks(self, handlers): - self.room_id = "a-room" - self.room_members = [] - - room_member_handler = handlers.room_member_handler = Mock(spec=[ - "get_joined_rooms_for_user", - "get_room_members", - "fetch_room_distributions_into", - ]) - self.room_member_handler = room_member_handler - - def get_rooms_for_user(user): - if user in self.room_members: - return defer.succeed([self.room_id]) - else: - return defer.succeed([]) - room_member_handler.get_joined_rooms_for_user = get_rooms_for_user - - def get_room_members(room_id): - if room_id == self.room_id: - return defer.succeed(self.room_members) - else: - return defer.succeed([]) - room_member_handler.get_room_members = get_room_members - - @defer.inlineCallbacks - def fetch_room_distributions_into(room_id, localusers=None, - remotedomains=None, ignore_user=None): - - members = yield get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if member.is_mine: - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - room_member_handler.fetch_room_distributions_into = ( - fetch_room_distributions_into) - - self.setUp_datastore_room_mocks(self.datastore) - - def setUp_datastore_room_mocks(self, datastore): - def get_room_hosts(room_id): - if room_id == self.room_id: - hosts = set([u.domain for u in self.room_members]) - return defer.succeed(hosts) - else: - return defer.succeed([]) - datastore.get_joined_hosts_for_room = get_room_hosts - - def user_rooms_intersect(userlist): - room_member_ids = map(lambda u: u.to_string(), self.room_members) - - shared = all(map(lambda i: i in room_member_ids, userlist)) - return defer.succeed(shared) - datastore.user_rooms_intersect = user_rooms_intersect - - @defer.inlineCallbacks - def setUp_users(self, hs): - # Some local users to test with - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - self.u_clementine = UserID.from_string("@clementine:test") - - for u in self.u_apple, self.u_banana, self.u_clementine: - yield self.datastore.create_presence(u.localpart) - - yield self.datastore.set_presence_state( - self.u_apple.localpart, {"state": ONLINE, "status_msg": "Online"} - ) - - # ID of a local user that does not exist - self.u_durian = UserID.from_string("@durian:test") - - # A remote user - self.u_cabbage = UserID.from_string("@cabbage:elsewhere") - - -class MockedDatastorePresenceTestCase(PresenceTestCase): - def make_datastore_mock(self): - datastore = Mock(spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - ]) - - self.setUp_datastore_federation_mocks(datastore) - self.setUp_datastore_presence_mocks(datastore) - - return datastore - - def setUp_datastore_federation_mocks(self, datastore): - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - } - datastore.get_destination_retry_timings.return_value = ( - defer.succeed(retry_timings_res) - ) - - def get_received_txn_response(*args): - return defer.succeed(None) - datastore.get_received_txn_response = get_received_txn_response - - def setUp_datastore_presence_mocks(self, datastore): - self.current_user_state = { - "apple": OFFLINE, - "banana": OFFLINE, - "clementine": OFFLINE, - "fig": OFFLINE, - } - - def get_presence_state(user_localpart): - return defer.succeed( - {"state": self.current_user_state[user_localpart], - "status_msg": None, - "mtime": 123456000} - ) - datastore.get_presence_state = get_presence_state - - def set_presence_state(user_localpart, new_state): - was = self.current_user_state[user_localpart] - self.current_user_state[user_localpart] = new_state["state"] - return defer.succeed({"state": was}) - datastore.set_presence_state = set_presence_state - - def get_presence_list(user_localpart, accepted): - if not user_localpart in self.PRESENCE_LIST: - return defer.succeed([]) - return defer.succeed([ - {"observed_user_id": u, "accepted": accepted} for u in - self.PRESENCE_LIST[user_localpart]]) - datastore.get_presence_list = get_presence_list - - def is_presence_visible(observed_localpart, observer_userid): - return True - datastore.is_presence_visible = is_presence_visible - - @defer.inlineCallbacks - def setUp_users(self, hs): - # Some local users to test with - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - self.u_clementine = UserID.from_string("@clementine:test") - self.u_durian = UserID.from_string("@durian:test") - self.u_elderberry = UserID.from_string("@elderberry:test") - self.u_fig = UserID.from_string("@fig:test") - - # Remote user - self.u_onion = UserID.from_string("@onion:farm") - self.u_potato = UserID.from_string("@potato:remote") - - yield - - -class PresenceStateTestCase(PresenceTestCase): - """ Tests presence management. """ - @defer.inlineCallbacks - def setUp(self): - yield super(PresenceStateTestCase, self).setUp() - - self.mock_start = Mock() - self.mock_stop = Mock() - - self.handler.start_polling_presence = self.mock_start - self.handler.stop_polling_presence = self.mock_stop - - @defer.inlineCallbacks - def test_get_my_state(self): - state = yield self.handler.get_state( - target_user=self.u_apple, auth_user=self.u_apple - ) - - self.assertEquals( - {"presence": ONLINE, "status_msg": "Online"}, - state - ) - - @defer.inlineCallbacks - def test_get_allowed_state(self): - yield self.datastore.allow_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=self.u_banana.to_string(), - ) - - state = yield self.handler.get_state( - target_user=self.u_apple, auth_user=self.u_banana - ) +from synapse.handlers.presence import ( + handle_update, handle_timeout, + IDLE_TIMER, SYNC_ONLINE_TIMEOUT, LAST_ACTIVE_GRANULARITY, FEDERATION_TIMEOUT, + FEDERATION_PING_INTERVAL, +) +from synapse.storage.presence import UserPresenceState - self.assertEquals( - {"presence": ONLINE, "status_msg": "Online"}, - state - ) - @defer.inlineCallbacks - def test_get_same_room_state(self): - self.room_members = [self.u_apple, self.u_clementine] +class PresenceUpdateTestCase(unittest.TestCase): + def test_offline_to_online(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 - state = yield self.handler.get_state( - target_user=self.u_apple, auth_user=self.u_clementine + prev_state = UserPresenceState.default(user_id) + new_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, ) - self.assertEquals( - {"presence": ONLINE, "status_msg": "Online"}, - state + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now ) - @defer.inlineCallbacks - def test_get_disallowed_state(self): - self.room_members = [] + self.assertTrue(persist_and_notify) + self.assertTrue(state.currently_active) + self.assertEquals(new_state.state, state.state) + self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEquals(state.last_federation_update_ts, now) - yield self.assertFailure( - self.handler.get_state( - target_user=self.u_apple, auth_user=self.u_clementine + self.assertEquals(wheel_timer.insert.call_count, 3) + wheel_timer.insert.assert_has_calls([ + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + IDLE_TIMER ), - SynapseError - ) - - @defer.inlineCallbacks - def test_set_my_state(self): - yield self.handler.set_state( - target_user=self.u_apple, auth_user=self.u_apple, - state={"presence": UNAVAILABLE, "status_msg": "Away"}) - - self.assertEquals( - {"state": UNAVAILABLE, - "status_msg": "Away", - "mtime": 1000000}, - (yield self.datastore.get_presence_state(self.u_apple.localpart)) - ) - - self.mock_start.assert_called_with(self.u_apple, - state={ - "presence": UNAVAILABLE, - "status_msg": "Away", - "last_active": 1000000, # MockClock - }) - - yield self.handler.set_state( - target_user=self.u_apple, auth_user=self.u_apple, - state={"presence": OFFLINE}) - - self.mock_stop.assert_called_with(self.u_apple) - - -class PresenceInvitesTestCase(PresenceTestCase): - """ Tests presence management. """ - @defer.inlineCallbacks - def setUp(self): - yield super(PresenceInvitesTestCase, self).setUp() - - self.mock_start = Mock() - self.mock_stop = Mock() - - self.handler.start_polling_presence = self.mock_start - self.handler.stop_polling_presence = self.mock_stop - - @defer.inlineCallbacks - def test_invite_local(self): - # TODO(paul): This test will likely break if/when real auth permissions - # are added; for now the HS will always accept any invite - - yield self.handler.send_presence_invite( - observer_user=self.u_apple, observed_user=self.u_banana) - - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 1}], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) - ) - self.assertTrue( - (yield self.datastore.is_presence_visible( - observed_localpart=self.u_banana.localpart, - observer_userid=self.u_apple.to_string(), - )) - ) - - self.mock_start.assert_called_with( - self.u_apple, target_user=self.u_banana) - - @defer.inlineCallbacks - def test_invite_local_nonexistant(self): - yield self.handler.send_presence_invite( - observer_user=self.u_apple, observed_user=self.u_durian) - - self.assertEquals( - [], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) - ) - - @defer.inlineCallbacks - def test_invite_remote(self): - # Use a different destination, otherwise retry logic might fail the - # request - u_rocket = UserID.from_string("@rocket:there") - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("there", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("there", "m.presence_invite", - content={ - "observer_user": "@apple:test", - "observed_user": "@rocket:there", - } - ), - json_data_callback=ANY, - long_retries=True, + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT ), - defer.succeed((200, "OK")) - ) + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + ), + ], any_order=True) - yield self.handler.send_presence_invite( - observer_user=self.u_apple, observed_user=u_rocket) + def test_online_to_online(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 - self.assertEquals( - [{"observed_user_id": "@rocket:there", "accepted": 0}], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) + prev_state = UserPresenceState.default(user_id) + prev_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + currently_active=True, ) - yield put_json.await_calls() - - @defer.inlineCallbacks - def test_accept_remote(self): - # TODO(paul): This test will likely break if/when real auth permissions - # are added; for now the HS will always accept any invite - - # Use a different destination, otherwise retry logic might fail the - # request - u_rocket = UserID.from_string("@rocket:moon") - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("moon", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("moon", "m.presence_accept", - content={ - "observer_user": "@rocket:moon", - "observed_user": "@apple:test", - } - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + new_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, ) - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("elsewhere", "m.presence_invite", - content={ - "observer_user": "@rocket:moon", - "observed_user": "@apple:test", - } - ) + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now ) - self.assertTrue( - (yield self.datastore.is_presence_visible( - observed_localpart=self.u_apple.localpart, - observer_userid=u_rocket.to_string(), - )) - ) + self.assertFalse(persist_and_notify) + self.assertTrue(federation_ping) + self.assertTrue(state.currently_active) + self.assertEquals(new_state.state, state.state) + self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEquals(state.last_federation_update_ts, now) - yield put_json.await_calls() - - @defer.inlineCallbacks - def test_invited_remote_nonexistant(self): - # Use a different destination, otherwise retry logic might fail the - # request - u_rocket = UserID.from_string("@rocket:sun") - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("sun", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("sun", "m.presence_deny", - content={ - "observer_user": "@rocket:sun", - "observed_user": "@durian:test", - } - ), - json_data_callback=ANY, - long_retries=True, + self.assertEquals(wheel_timer.insert.call_count, 3) + wheel_timer.insert.assert_has_calls([ + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + IDLE_TIMER ), - defer.succeed((200, "OK")) - ) - - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("sun", "m.presence_invite", - content={ - "observer_user": "@rocket:sun", - "observed_user": "@durian:test", - } - ) - ) + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT + ), + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + ), + ], any_order=True) - yield put_json.await_calls() + def test_online_to_online_last_active(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 - @defer.inlineCallbacks - def test_accepted_remote(self): - yield self.datastore.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_cabbage.to_string(), + prev_state = UserPresenceState.default(user_id) + prev_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1, + currently_active=True, ) - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("elsewhere", "m.presence_accept", - content={ - "observer_user": "@apple:test", - "observed_user": "@cabbage:elsewhere", - } - ) + new_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, ) - self.assertEquals( - [{"observed_user_id": "@cabbage:elsewhere", "accepted": 1}], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now ) - self.mock_start.assert_called_with( - self.u_apple, target_user=self.u_cabbage) - - @defer.inlineCallbacks - def test_denied_remote(self): - yield self.datastore.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid="@eggplant:elsewhere", - ) + self.assertTrue(persist_and_notify) + self.assertFalse(state.currently_active) + self.assertEquals(new_state.state, state.state) + self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEquals(state.last_federation_update_ts, now) - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("elsewhere", "m.presence_deny", - content={ - "observer_user": "@apple:test", - "observed_user": "@eggplant:elsewhere", - } + self.assertEquals(wheel_timer.insert.call_count, 2) + wheel_timer.insert.assert_has_calls([ + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + IDLE_TIMER + ), + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT ) - ) - - self.assertEquals( - [], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) - ) - - @defer.inlineCallbacks - def test_drop_local(self): - yield self.datastore.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - yield self.datastore.set_presence_list_accepted( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - yield self.handler.drop( - observer_user=self.u_apple, - observed_user=self.u_banana, - ) - - self.assertEquals( - [], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) - ) - - self.mock_stop.assert_called_with( - self.u_apple, target_user=self.u_banana) - - @defer.inlineCallbacks - def test_drop_remote(self): - yield self.datastore.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_cabbage.to_string(), - ) - yield self.datastore.set_presence_list_accepted( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_cabbage.to_string(), - ) - - yield self.handler.drop( - observer_user=self.u_apple, - observed_user=self.u_cabbage, - ) - - self.assertEquals( - [], - (yield self.datastore.get_presence_list(self.u_apple.localpart)) - ) - - @defer.inlineCallbacks - def test_get_presence_list(self): - yield self.datastore.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - yield self.datastore.set_presence_list_accepted( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - presence = yield self.handler.get_presence_list( - observer_user=self.u_apple) - - self.assertEquals([ - {"observed_user": self.u_banana, - "presence": OFFLINE, - "accepted": 1}, - ], presence) - - -class PresencePushTestCase(MockedDatastorePresenceTestCase): - """ Tests steady-state presence status updates. - - They assert that presence state update messages are pushed around the place - when users change state, presuming that the watches are all established. - - These tests are MASSIVELY fragile currently as they poke internals of the - presence handler; namely the _local_pushmap and _remote_recvmap. - BE WARNED... - """ - PRESENCE_LIST = { - 'apple': [ "@banana:test", "@clementine:test" ], - 'banana': [ "@apple:test" ], - } - - @defer.inlineCallbacks - def test_push_local(self): - self.room_members = [self.u_apple, self.u_elderberry] - - self.datastore.set_presence_state.return_value = defer.succeed( - {"state": ONLINE} - ) - - # TODO(paul): Gut-wrenching - self.handler._user_cachemap[self.u_apple] = UserPresenceCache() - self.handler._user_cachemap[self.u_apple].update( - {"presence": OFFLINE}, serial=0 - ) - apple_set = self.handler._local_pushmap.setdefault("apple", set()) - apple_set.add(self.u_banana) - apple_set.add(self.u_clementine) - - self.assertEquals(self.event_source.get_current_key(), 0) - - yield self.handler.set_state(self.u_apple, self.u_apple, - {"presence": ONLINE} - ) - - # Apple sees self-reflection even without room_id - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=0, - ) - - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@apple:test", - "presence": ONLINE, - "last_active_ago": 0, - }}, - ], - msg="Presence event should be visible to self-reflection" - ) - - # Apple sees self-reflection - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=0, - room_ids=[self.room_id], - ) - - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@apple:test", - "presence": ONLINE, - "last_active_ago": 0, - }}, - ], - msg="Presence event should be visible to self-reflection" - ) - - config = SourcePaginationConfig(from_key=1, to_key=0) - (chunk, _) = yield self.event_source.get_pagination_rows( - self.u_apple, config, None - ) - self.assertEquals(chunk, - [ - {"type": "m.presence", - "content": { - "user_id": "@apple:test", - "presence": ONLINE, - "last_active_ago": 0, - }}, - ] - ) - - # Banana sees it because of presence subscription - (events, _) = yield self.event_source.get_new_events( - user=self.u_banana, - from_key=0, - room_ids=[self.room_id], - ) - - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@apple:test", - "presence": ONLINE, - "last_active_ago": 0, - }}, - ], - msg="Presence event should be visible to explicit subscribers" - ) - - # Elderberry sees it because of same room - (events, _) = yield self.event_source.get_new_events( - user=self.u_elderberry, - from_key=0, - room_ids=[self.room_id], - ) - - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@apple:test", - "presence": ONLINE, - "last_active_ago": 0, - }}, - ], - msg="Presence event should be visible to other room members" - ) - - # Durian is not in the room, should not see this event - (events, _) = yield self.event_source.get_new_events( - user=self.u_durian, - from_key=0, - room_ids=[], - ) + ], any_order=True) - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, [], - msg="Presence event should not be visible to others" - ) + def test_remote_ping_timer(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 - presence = yield self.handler.get_presence_list( - observer_user=self.u_apple, accepted=True) - - self.assertEquals( - [ - {"observed_user": self.u_banana, - "presence": OFFLINE, - "accepted": True}, - {"observed_user": self.u_clementine, - "presence": OFFLINE, - "accepted": True}, - ], - presence + prev_state = UserPresenceState.default(user_id) + prev_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, ) - # TODO(paul): Gut-wrenching - banana_set = self.handler._local_pushmap.setdefault("banana", set()) - banana_set.add(self.u_apple) - - yield self.handler.set_state(self.u_banana, self.u_banana, - {"presence": ONLINE} + new_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, ) - self.clock.advance_time(2) - - presence = yield self.handler.get_presence_list( - observer_user=self.u_apple, accepted=True) - - self.assertEquals([ - {"observed_user": self.u_banana, - "presence": ONLINE, - "last_active_ago": 2000, - "accepted": True}, - {"observed_user": self.u_clementine, - "presence": OFFLINE, - "accepted": True}, - ], presence) - - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=1, + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now ) - self.assertEquals(self.event_source.get_current_key(), 2) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@banana:test", - "presence": ONLINE, - "last_active_ago": 2000 - }}, - ] - ) + self.assertFalse(persist_and_notify) + self.assertFalse(federation_ping) + self.assertFalse(state.currently_active) + self.assertEquals(new_state.state, state.state) + self.assertEquals(new_state.status_msg, state.status_msg) - @defer.inlineCallbacks - def test_push_remote(self): - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("farm", - path=ANY, # Can't guarantee which txn ID will be which - data=_expect_edu("farm", "m.presence", - content={ - "push": [ - {"user_id": "@apple:test", - "presence": u"online", - "last_active_ago": 0}, - ], - } - ), - json_data_callback=ANY, - long_retries=True, + self.assertEquals(wheel_timer.insert.call_count, 1) + wheel_timer.insert.assert_has_calls([ + call( + now=now, + obj=user_id, + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT ), - defer.succeed((200, "OK")) - ) - put_json.expect_call_and_return( - call("remote", - path=ANY, # Can't guarantee which txn ID will be which - data=_expect_edu("remote", "m.presence", - content={ - "push": [ - {"user_id": "@apple:test", - "presence": u"online", - "last_active_ago": 0}, - ], - } - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) - ) + ], any_order=True) - self.room_members = [self.u_apple, self.u_onion] + def test_online_to_offline(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 - self.datastore.set_presence_state.return_value = defer.succeed( - {"state": ONLINE} + prev_state = UserPresenceState.default(user_id) + prev_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + currently_active=True, ) - # TODO(paul): Gut-wrenching - self.handler._user_cachemap[self.u_apple] = UserPresenceCache() - self.handler._user_cachemap[self.u_apple].update( - {"presence": OFFLINE}, serial=0 + new_state = prev_state.copy_and_replace( + state=PresenceState.OFFLINE, ) - apple_set = self.handler._remote_sendmap.setdefault("apple", set()) - apple_set.add(self.u_potato.domain) - yield self.handler.set_state(self.u_apple, self.u_apple, - {"presence": ONLINE} + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now ) - yield put_json.await_calls() + self.assertTrue(persist_and_notify) + self.assertEquals(new_state.state, state.state) + self.assertEquals(state.last_federation_update_ts, now) - @defer.inlineCallbacks - def test_recv_remote(self): - self.room_members = [self.u_apple, self.u_banana, self.u_potato] + self.assertEquals(wheel_timer.insert.call_count, 0) - self.assertEquals(self.event_source.get_current_key(), 0) - - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("elsewhere", "m.presence", - content={ - "push": [ - {"user_id": "@potato:remote", - "presence": "online", - "last_active_ago": 1000}, - ], - } - ) - ) + def test_online_to_idle(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=0, - room_ids=[self.room_id], + prev_state = UserPresenceState.default(user_id) + prev_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + currently_active=True, ) - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@potato:remote", - "presence": ONLINE, - "last_active_ago": 1000, - }} - ] + new_state = prev_state.copy_and_replace( + state=PresenceState.UNAVAILABLE, ) - self.clock.advance_time(2) - - state = yield self.handler.get_state(self.u_potato, self.u_apple) - - self.assertEquals( - {"presence": ONLINE, "last_active_ago": 3000}, - state + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now ) - @defer.inlineCallbacks - def test_recv_remote_offline(self): - """ Various tests relating to SYN-261 """ - - self.room_members = [self.u_apple, self.u_banana, self.u_potato] + self.assertTrue(persist_and_notify) + self.assertEquals(new_state.state, state.state) + self.assertEquals(state.last_federation_update_ts, now) + self.assertEquals(new_state.state, state.state) + self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(self.event_source.get_current_key(), 0) - - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("elsewhere", "m.presence", - content={ - "push": [ - {"user_id": "@potato:remote", - "presence": "offline"}, - ], - } + self.assertEquals(wheel_timer.insert.call_count, 1) + wheel_timer.insert.assert_has_calls([ + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT ) - ) - - self.assertEquals(self.event_source.get_current_key(), 1) - - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=0, - room_ids=[self.room_id,] - ) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@potato:remote", - "presence": OFFLINE, - }} - ] - ) - - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000001/", - _make_edu_json("elsewhere", "m.presence", - content={ - "push": [ - {"user_id": "@potato:remote", - "presence": "online"}, - ], - } - ) - ) - - self.assertEquals(self.event_source.get_current_key(), 2) - - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=0, - room_ids=[self.room_id,] - ) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@potato:remote", - "presence": ONLINE, - }} - ] - ) - - @defer.inlineCallbacks - def test_join_room_local(self): - self.room_members = [self.u_apple, self.u_banana] - - self.assertEquals(self.event_source.get_current_key(), 0) - - # TODO(paul): Gut-wrenching - self.handler._user_cachemap[self.u_clementine] = UserPresenceCache() - self.handler._user_cachemap[self.u_clementine].update( - { - "presence": PresenceState.ONLINE, - "last_active": self.clock.time_msec(), - }, self.u_clementine - ) - - yield self.distributor.fire("user_joined_room", self.u_clementine, - self.room_id - ) + ], any_order=True) - self.room_members.append(self.u_clementine) - (events, _) = yield self.event_source.get_new_events( - user=self.u_apple, - from_key=0, - ) +class PresenceTimeoutTestCase(unittest.TestCase): + def test_idle_timer(self): + user_id = "@foo:bar" + now = 5000000 - self.assertEquals(self.event_source.get_current_key(), 1) - self.assertEquals(events, - [ - {"type": "m.presence", - "content": { - "user_id": "@clementine:test", - "presence": ONLINE, - "last_active_ago": 0, - }} - ] + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now - IDLE_TIMER - 1, + last_user_sync_ts=now, ) - @defer.inlineCallbacks - def test_join_room_remote(self): - ## Sending local user state to a newly-joined remote user - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("remote", - path=ANY, # Can't guarantee which txn ID will be which - data=_expect_edu("remote", "m.presence", - content={ - "push": [ - {"user_id": "@apple:test", - "presence": "online"}, - ], - } - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) - ) - put_json.expect_call_and_return( - call("remote", - path=ANY, # Can't guarantee which txn ID will be which - data=_expect_edu("remote", "m.presence", - content={ - "push": [ - {"user_id": "@banana:test", - "presence": "offline"}, - ], - } - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + new_state = handle_timeout( + state, is_mine=True, user_to_num_current_syncs={}, now=now ) - # TODO(paul): Gut-wrenching - self.handler._user_cachemap[self.u_apple] = UserPresenceCache() - self.handler._user_cachemap[self.u_apple].update( - {"presence": PresenceState.ONLINE}, self.u_apple) - self.room_members = [self.u_apple, self.u_banana] + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) - yield self.distributor.fire("user_joined_room", self.u_potato, - self.room_id - ) + def test_sync_timeout(self): + user_id = "@foo:bar" + now = 5000000 - yield put_json.await_calls() - - ## Sending newly-joined local user state to remote users - - put_json.expect_call_and_return( - call("remote", - path="/_matrix/federation/v1/send/1000002/", - data=_expect_edu("remote", "m.presence", - content={ - "push": [ - {"user_id": "@clementine:test", - "presence": "online"}, - ], - } - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, ) - self.handler._user_cachemap[self.u_clementine] = UserPresenceCache() - self.handler._user_cachemap[self.u_clementine].update( - {"presence": ONLINE}, self.u_clementine) - self.room_members.append(self.u_potato) - - yield self.distributor.fire("user_joined_room", self.u_clementine, - self.room_id + new_state = handle_timeout( + state, is_mine=True, user_to_num_current_syncs={}, now=now ) - put_json.await_calls() - - -class PresencePollingTestCase(MockedDatastorePresenceTestCase): - """ Tests presence status polling. """ - - # For this test, we have three local users; apple is watching and is - # watched by the other two, but the others don't watch each other. - # Additionally clementine is watching a remote user. - PRESENCE_LIST = { - 'apple': [ "@banana:test", "@clementine:test" ], - 'banana': [ "@apple:test" ], - 'clementine': [ "@apple:test", "@potato:remote" ], - 'fig': [ "@potato:remote" ], - } - - @defer.inlineCallbacks - def setUp(self): - yield super(PresencePollingTestCase, self).setUp() - - self.mock_update_client = Mock() + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.OFFLINE) - def update(*args,**kwargs): - return defer.succeed(None) - self.mock_update_client.side_effect = update + def test_sync_online(self): + user_id = "@foo:bar" + now = 5000000 - self.handler.push_update_to_clients = self.mock_update_client - - @defer.inlineCallbacks - def test_push_local(self): - # apple goes online - yield self.handler.set_state( - target_user=self.u_apple, auth_user=self.u_apple, - state={"presence": ONLINE} + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1, + last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, ) - # apple should see both banana and clementine currently offline - self.mock_update_client.assert_has_calls([ - call(users_to_push=[self.u_apple]), - call(users_to_push=[self.u_apple]), - ], any_order=True) - - # Gut-wrenching tests - self.assertTrue("banana" in self.handler._local_pushmap) - self.assertTrue(self.u_apple in self.handler._local_pushmap["banana"]) - self.assertTrue("clementine" in self.handler._local_pushmap) - self.assertTrue(self.u_apple in self.handler._local_pushmap["clementine"]) - - self.mock_update_client.reset_mock() - - # banana goes online - yield self.handler.set_state( - target_user=self.u_banana, auth_user=self.u_banana, - state={"presence": ONLINE} + new_state = handle_timeout( + state, is_mine=True, user_to_num_current_syncs={ + user_id: 1, + }, now=now ) - # apple and banana should now both see each other online - self.mock_update_client.assert_has_calls([ - call(users_to_push=set([self.u_apple]), room_ids=[]), - call(users_to_push=[self.u_banana]), - ], any_order=True) + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.ONLINE) - self.assertTrue("apple" in self.handler._local_pushmap) - self.assertTrue(self.u_banana in self.handler._local_pushmap["apple"]) + def test_federation_ping(self): + user_id = "@foo:bar" + now = 5000000 - self.mock_update_client.reset_mock() - - # apple goes offline - yield self.handler.set_state( - target_user=self.u_apple, auth_user=self.u_apple, - state={"presence": OFFLINE} - ) - - # banana should now be told apple is offline - self.mock_update_client.assert_has_calls([ - call(users_to_push=set([self.u_banana, self.u_apple]), room_ids=[]), - ], any_order=True) - - self.assertFalse("banana" in self.handler._local_pushmap) - self.assertFalse("clementine" in self.handler._local_pushmap) - - @defer.inlineCallbacks - def test_remote_poll_send(self): - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("remote", - path=ANY, - data=_expect_edu("remote", "m.presence", - content={ - "poll": [ "@potato:remote" ], - }, - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + last_user_sync_ts=now, + last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, ) - put_json.expect_call_and_return( - call("remote", - path=ANY, - data=_expect_edu("remote", "m.presence", - content={ - "push": [ { - "user_id": "@clementine:test", - "presence": OFFLINE, - }], - }, - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + new_state = handle_timeout( + state, is_mine=True, user_to_num_current_syncs={}, now=now ) - # clementine goes online - yield self.handler.set_state( - target_user=self.u_clementine, auth_user=self.u_clementine, - state={"presence": ONLINE} - ) + self.assertIsNotNone(new_state) + self.assertEquals(new_state, new_state) - yield put_json.await_calls() + def test_no_timeout(self): + user_id = "@foo:bar" + now = 5000000 - # Gut-wrenching tests - self.assertTrue(self.u_potato in self.handler._remote_recvmap, - msg="expected potato to be in _remote_recvmap" - ) - self.assertTrue(self.u_clementine in - self.handler._remote_recvmap[self.u_potato]) - - - put_json.expect_call_and_return( - call("remote", - path=ANY, - data=_expect_edu("remote", "m.presence", - content={ - "push": [ { - "user_id": "@fig:test", - "presence": OFFLINE, - }], - }, - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + last_user_sync_ts=now, + last_federation_update_ts=now, ) - # fig goes online; shouldn't send a second poll - yield self.handler.set_state( - target_user=self.u_fig, auth_user=self.u_fig, - state={"presence": ONLINE} + new_state = handle_timeout( + state, is_mine=True, user_to_num_current_syncs={}, now=now ) - # reactor.iterate(delay=0) + self.assertIsNone(new_state) - yield put_json.await_calls() + def test_federation_timeout(self): + user_id = "@foo:bar" + now = 5000000 - # fig goes offline - yield self.handler.set_state( - target_user=self.u_fig, auth_user=self.u_fig, - state={"presence": OFFLINE} + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + last_user_sync_ts=now, + last_federation_update_ts=now - FEDERATION_TIMEOUT - 1, ) - reactor.iterate(delay=0) - - put_json.assert_had_no_calls() - - put_json.expect_call_and_return( - call("remote", - path=ANY, - data=_expect_edu("remote", "m.presence", - content={ - "unpoll": [ "@potato:remote" ], - }, - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) - ) - - # clementine goes offline - yield self.handler.set_state( - target_user=self.u_clementine, auth_user=self.u_clementine, - state={"presence": OFFLINE} + new_state = handle_timeout( + state, is_mine=False, user_to_num_current_syncs={}, now=now ) - yield put_json.await_calls() + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertFalse(self.u_potato in self.handler._remote_recvmap, - msg="expected potato not to be in _remote_recvmap" - ) + def test_last_active(self): + user_id = "@foo:bar" + now = 5000000 - @defer.inlineCallbacks - def test_remote_poll_receive(self): - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call("remote", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("remote", "m.presence", - content={ - "push": [ - {"user_id": "@banana:test", - "presence": "offline", - "status_msg": None}, - ], - }, - ), - json_data_callback=ANY, - long_retries=True, - ), - defer.succeed((200, "OK")) + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1, + last_user_sync_ts=now, + last_federation_update_ts=now, ) - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000000/", - _make_edu_json("remote", "m.presence", - content={ - "poll": [ "@banana:test" ], - }, - ) - ) - - yield put_json.await_calls() - - # Gut-wrenching tests - self.assertTrue(self.u_banana in self.handler._remote_sendmap) - - yield self.mock_federation_resource.trigger("PUT", - "/_matrix/federation/v1/send/1000001/", - _make_edu_json("remote", "m.presence", - content={ - "unpoll": [ "@banana:test" ], - } - ) + new_state = handle_timeout( + state, is_mine=True, user_to_num_current_syncs={}, now=now ) - # Gut-wrenching tests - self.assertFalse(self.u_banana in self.handler._remote_sendmap) + self.assertIsNotNone(new_state) + self.assertEquals(state, new_state) diff --git a/tests/handlers/test_presencelike.py b/tests/handlers/test_presencelike.py deleted file mode 100644 index 76f6ba5e7b..0000000000 --- a/tests/handlers/test_presencelike.py +++ /dev/null @@ -1,311 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This file contains tests of the "presence-like" data that is shared between -presence and profiles; namely, the displayname and avatar_url.""" - -from tests import unittest -from twisted.internet import defer - -from mock import Mock, call, ANY, NonCallableMock - -from ..utils import MockClock, setup_test_homeserver - -from synapse.api.constants import PresenceState -from synapse.handlers.presence import PresenceHandler -from synapse.handlers.profile import ProfileHandler -from synapse.types import UserID - - -OFFLINE = PresenceState.OFFLINE -UNAVAILABLE = PresenceState.UNAVAILABLE -ONLINE = PresenceState.ONLINE - - -class MockReplication(object): - def __init__(self): - self.edu_handlers = {} - - def register_edu_handler(self, edu_type, handler): - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - pass - - def received_edu(self, origin, edu_type, content): - self.edu_handlers[edu_type](origin, content) - - -class PresenceAndProfileHandlers(object): - def __init__(self, hs): - self.presence_handler = PresenceHandler(hs) - self.profile_handler = ProfileHandler(hs) - - -class PresenceProfilelikeDataTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - clock=MockClock(), - datastore=Mock(spec=[ - "set_presence_state", - "is_presence_visible", - "set_profile_displayname", - "get_rooms_for_user", - ]), - handlers=None, - resource_for_federation=Mock(), - http_client=None, - replication_layer=MockReplication(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - hs.handlers = PresenceAndProfileHandlers(hs) - - self.datastore = hs.get_datastore() - - self.replication = hs.get_replication_layer() - self.replication.send_edu = Mock() - - def send_edu(*args, **kwargs): - # print "send_edu: %s, %s" % (args, kwargs) - return defer.succeed((200, "OK")) - self.replication.send_edu.side_effect = send_edu - - def get_profile_displayname(user_localpart): - return defer.succeed("Frank") - self.datastore.get_profile_displayname = get_profile_displayname - - def is_presence_visible(*args, **kwargs): - return defer.succeed(False) - self.datastore.is_presence_visible = is_presence_visible - - def get_profile_avatar_url(user_localpart): - return defer.succeed("http://foo") - self.datastore.get_profile_avatar_url = get_profile_avatar_url - - self.presence_list = [ - {"observed_user_id": "@banana:test", "accepted": True}, - {"observed_user_id": "@clementine:test", "accepted": True}, - ] - def get_presence_list(user_localpart, accepted=None): - return defer.succeed(self.presence_list) - self.datastore.get_presence_list = get_presence_list - - def user_rooms_intersect(userlist): - return defer.succeed(False) - self.datastore.user_rooms_intersect = user_rooms_intersect - - self.handlers = hs.get_handlers() - - self.mock_update_client = Mock() - def update(*args, **kwargs): - # print "mock_update_client: %s, %s" %(args, kwargs) - return defer.succeed(None) - self.mock_update_client.side_effect = update - - self.handlers.presence_handler.push_update_to_clients = ( - self.mock_update_client) - - hs.handlers.room_member_handler = Mock(spec=[ - "get_joined_rooms_for_user", - ]) - hs.handlers.room_member_handler.get_joined_rooms_for_user = ( - lambda u: defer.succeed([])) - - # Some local users to test with - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - self.u_clementine = UserID.from_string("@clementine:test") - - # Remote user - self.u_potato = UserID.from_string("@potato:remote") - - self.mock_get_joined = ( - self.datastore.get_rooms_for_user - ) - - @defer.inlineCallbacks - def test_set_my_state(self): - self.presence_list = [ - {"observed_user_id": "@banana:test", "accepted": True}, - {"observed_user_id": "@clementine:test", "accepted": True}, - ] - - mocked_set = self.datastore.set_presence_state - mocked_set.return_value = defer.succeed({"state": OFFLINE}) - - yield self.handlers.presence_handler.set_state( - target_user=self.u_apple, auth_user=self.u_apple, - state={"presence": UNAVAILABLE, "status_msg": "Away"}) - - mocked_set.assert_called_with("apple", - {"state": UNAVAILABLE, "status_msg": "Away"} - ) - - @defer.inlineCallbacks - def test_push_local(self): - def get_joined(*args): - return defer.succeed([]) - - self.mock_get_joined.side_effect = get_joined - - self.presence_list = [ - {"observed_user_id": "@banana:test", "accepted": True}, - {"observed_user_id": "@clementine:test", "accepted": True}, - ] - - self.datastore.set_presence_state.return_value = defer.succeed( - {"state": ONLINE} - ) - - # TODO(paul): Gut-wrenching - from synapse.handlers.presence import UserPresenceCache - self.handlers.presence_handler._user_cachemap[self.u_apple] = ( - UserPresenceCache() - ) - self.handlers.presence_handler._user_cachemap[self.u_apple].update( - {"presence": OFFLINE}, serial=0 - ) - apple_set = self.handlers.presence_handler._local_pushmap.setdefault( - "apple", set()) - apple_set.add(self.u_banana) - apple_set.add(self.u_clementine) - - yield self.handlers.presence_handler.set_state(self.u_apple, - self.u_apple, {"presence": ONLINE} - ) - yield self.handlers.presence_handler.set_state(self.u_banana, - self.u_banana, {"presence": ONLINE} - ) - - presence = yield self.handlers.presence_handler.get_presence_list( - observer_user=self.u_apple, accepted=True) - - self.assertEquals([ - {"observed_user": self.u_banana, - "presence": ONLINE, - "last_active_ago": 0, - "displayname": "Frank", - "avatar_url": "http://foo", - "accepted": True}, - {"observed_user": self.u_clementine, - "presence": OFFLINE, - "accepted": True} - ], presence) - - self.mock_update_client.assert_has_calls([ - call( - users_to_push={self.u_apple, self.u_banana, self.u_clementine}, - room_ids=[] - ), - ], any_order=True) - - self.mock_update_client.reset_mock() - - self.datastore.set_profile_displayname.return_value = defer.succeed( - None) - - yield self.handlers.profile_handler.set_displayname(self.u_apple, - self.u_apple, "I am an Apple") - - self.mock_update_client.assert_has_calls([ - call( - users_to_push={self.u_apple, self.u_banana, self.u_clementine}, - room_ids=[], - ), - ], any_order=True) - - @defer.inlineCallbacks - def test_push_remote(self): - self.presence_list = [ - {"observed_user_id": "@potato:remote", "accepted": True}, - ] - - self.datastore.set_presence_state.return_value = defer.succeed( - {"state": ONLINE} - ) - - # TODO(paul): Gut-wrenching - from synapse.handlers.presence import UserPresenceCache - self.handlers.presence_handler._user_cachemap[self.u_apple] = ( - UserPresenceCache() - ) - self.handlers.presence_handler._user_cachemap[self.u_apple].update( - {"presence": OFFLINE}, serial=0 - ) - apple_set = self.handlers.presence_handler._remote_sendmap.setdefault( - "apple", set()) - apple_set.add(self.u_potato.domain) - - yield self.handlers.presence_handler.set_state(self.u_apple, - self.u_apple, {"presence": ONLINE} - ) - - self.replication.send_edu.assert_called_with( - destination="remote", - edu_type="m.presence", - content={ - "push": [ - {"user_id": "@apple:test", - "presence": "online", - "last_active_ago": 0, - "displayname": "Frank", - "avatar_url": "http://foo"}, - ], - }, - ) - - @defer.inlineCallbacks - def test_recv_remote(self): - self.presence_list = [ - {"observed_user_id": "@banana:test"}, - {"observed_user_id": "@clementine:test"}, - ] - - # TODO(paul): Gut-wrenching - potato_set = self.handlers.presence_handler._remote_recvmap.setdefault( - self.u_potato, set() - ) - potato_set.add(self.u_apple) - - yield self.replication.received_edu( - "remote", "m.presence", { - "push": [ - {"user_id": "@potato:remote", - "presence": "online", - "displayname": "Frank", - "avatar_url": "http://foo"}, - ], - } - ) - - self.mock_update_client.assert_called_with( - users_to_push=set([self.u_apple]), - room_ids=[], - ) - - state = yield self.handlers.presence_handler.get_state(self.u_potato, - self.u_apple) - - self.assertEquals( - {"presence": ONLINE, - "displayname": "Frank", - "avatar_url": "http://foo"}, - state) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 237fc8223c..a87703bbfd 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -41,8 +41,10 @@ class ProfileTestCase(unittest.TestCase): ]) self.query_handlers = {} + def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler + self.mock_federation.register_query_handler = register_query_handler hs = yield setup_test_homeserver( @@ -63,16 +65,13 @@ class ProfileTestCase(unittest.TestCase): self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") - self.bob = UserID.from_string("@4567:test") + self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_handlers().profile_handler - # TODO(paul): Icky signal declarings.. booo - hs.get_distributor().declare("changed_presencelike_data") - @defer.inlineCallbacks def test_get_my_name(self): yield self.store.set_profile_displayname( @@ -136,8 +135,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url(self.frank, self.frank, - "http://my.server/pic.gif") + yield self.handler.set_avatar_url( + self.frank, self.frank, "http://my.server/pic.gif" + ) self.assertEquals( (yield self.store.get_profile_avatar_url(self.frank.localpart)), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 763c04d667..3955e7f5b1 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -138,9 +138,9 @@ class TypingNotificationsTestCase(unittest.TestCase): self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user @defer.inlineCallbacks - def fetch_room_distributions_into(room_id, localusers=None, - remotedomains=None, ignore_user=None): - + def fetch_room_distributions_into( + room_id, localusers=None, remotedomains=None, ignore_user=None + ): members = yield get_room_members(room_id) for member in members: if ignore_user is not None and member == ignore_user: @@ -153,7 +153,8 @@ class TypingNotificationsTestCase(unittest.TestCase): if remotedomains is not None: remotedomains.add(member.domain) self.room_member_handler.fetch_room_distributions_into = ( - fetch_room_distributions_into) + fetch_room_distributions_into + ) def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: @@ -207,9 +208,12 @@ class TypingNotificationsTestCase(unittest.TestCase): put_json = self.mock_http_client.put_json put_json.expect_call_and_return( - call("farm", + call( + "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("farm", "m.typing", + data=_expect_edu( + "farm", + "m.typing", content={ "room_id": self.room_id, "user_id": self.u_apple.to_string(), @@ -237,9 +241,12 @@ class TypingNotificationsTestCase(unittest.TestCase): self.assertEquals(self.event_source.get_current_key(), 0) - yield self.mock_federation_resource.trigger("PUT", + yield self.mock_federation_resource.trigger( + "PUT", "/_matrix/federation/v1/send/1000000/", - _make_edu_json("farm", "m.typing", + _make_edu_json( + "farm", + "m.typing", content={ "room_id": self.room_id, "user_id": self.u_onion.to_string(), @@ -257,16 +264,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0 ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_onion.to_string()], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.u_onion.to_string()], + }, + }]) @defer.inlineCallbacks def test_stopped_typing(self): @@ -274,9 +278,12 @@ class TypingNotificationsTestCase(unittest.TestCase): put_json = self.mock_http_client.put_json put_json.expect_call_and_return( - call("farm", + call( + "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("farm", "m.typing", + data=_expect_edu( + "farm", + "m.typing", content={ "room_id": self.room_id, "user_id": self.u_apple.to_string(), @@ -317,16 +324,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [], + }, + }]) @defer.inlineCallbacks def test_typing_timeout(self): @@ -351,16 +355,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.u_apple.to_string()], + }, + }]) self.clock.advance_time(11) @@ -373,16 +374,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=1, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [], + }, + }]) # SYN-230 - see if we can still set after timeout @@ -403,13 +401,10 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.u_apple.to_string()], + }, + }]) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index f9e5e5af01..f3c1927ce1 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -61,6 +61,9 @@ class CounterMetricTestCase(unittest.TestCase): 'vector{method="PUT"} 1', ]) + # Check that passing too few values errors + self.assertRaises(ValueError, counter.inc) + class CallbackMetricTestCase(unittest.TestCase): diff --git a/tests/replication/__init__.py b/tests/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py new file mode 100644 index 0000000000..38daaf87e2 --- /dev/null +++ b/tests/replication/test_resource.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.resource import ReplicationResource +from synapse.types import Requester, UserID + +from twisted.internet import defer +from tests import unittest +from tests.utils import setup_test_homeserver +from mock import Mock, NonCallableMock +import json +import contextlib + + +class ReplicationResourceCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "red", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.user = UserID.from_string("@seeing:red") + + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.resource = ReplicationResource(self.hs) + + @defer.inlineCallbacks + def test_streams(self): + # Passing "-1" returns the current stream positions + code, body = yield self.get(streams="-1") + self.assertEquals(code, 200) + self.assertEquals(body["streams"]["field_names"], ["name", "position"]) + position = body["streams"]["position"] + # Passing the current position returns an empty response after the + # timeout + get = self.get(streams=str(position), timeout="0") + self.hs.clock.advance_time_msec(1) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body, {}) + + @defer.inlineCallbacks + def test_events(self): + get = self.get(events="-1", timeout="0") + yield self.hs.get_handlers().room_creation_handler.create_room( + Requester(self.user, "", False), {} + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["events"]["field_names"], [ + "position", "internal", "json" + ]) + + @defer.inlineCallbacks + def test_presence(self): + get = self.get(presence="-1") + yield self.hs.get_handlers().presence_handler.set_state( + self.user, {"presence": "online"} + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["presence"]["field_names"], [ + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + ]) + + @defer.inlineCallbacks + def test_typing(self): + room_id = yield self.create_room() + get = self.get(typing="-1") + yield self.hs.get_handlers().typing_notification_handler.started_typing( + self.user, self.user, room_id, timeout=2 + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["typing"]["field_names"], [ + "position", "room_id", "typing" + ]) + + @defer.inlineCallbacks + def test_receipts(self): + room_id = yield self.create_room() + event_id = yield self.send_text_message(room_id, "Hello, World") + get = self.get(receipts="-1") + yield self.hs.get_handlers().receipts_handler.received_client_receipt( + room_id, "m.read", self.user.to_string(), event_id + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["receipts"]["field_names"], [ + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + ]) + + def _test_timeout(stream): + """Check that a request for the given stream timesout""" + @defer.inlineCallbacks + def test_timeout(self): + get = self.get(**{stream: "-1", "timeout": "0"}) + self.hs.clock.advance_time_msec(1) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body, {}) + test_timeout.__name__ = "test_timeout_%s" % (stream) + return test_timeout + + test_timeout_events = _test_timeout("events") + test_timeout_presence = _test_timeout("presence") + test_timeout_typing = _test_timeout("typing") + test_timeout_receipts = _test_timeout("receipts") + test_timeout_user_account_data = _test_timeout("user_account_data") + test_timeout_room_account_data = _test_timeout("room_account_data") + test_timeout_tag_account_data = _test_timeout("tag_account_data") + test_timeout_backfill = _test_timeout("backfill") + + @defer.inlineCallbacks + def send_text_message(self, room_id, message): + handler = self.hs.get_handlers().message_handler + event = yield handler.create_and_send_nonmember_event({ + "type": "m.room.message", + "content": {"body": "message", "msgtype": "m.text"}, + "room_id": room_id, + "sender": self.user.to_string(), + }) + defer.returnValue(event.event_id) + + @defer.inlineCallbacks + def create_room(self): + result = yield self.hs.get_handlers().room_creation_handler.create_room( + Requester(self.user, "", False), {} + ) + defer.returnValue(result["room_id"]) + + @defer.inlineCallbacks + def get(self, **params): + request = NonCallableMock(spec_set=[ + "write", "finish", "setResponseCode", "setHeader", "args", + "method", "processing" + ]) + + request.method = "GET" + request.args = {k: [v] for k, v in params.items()} + + @contextlib.contextmanager + def processing(): + yield + request.processing = processing + + yield self.resource._async_render_GET(request) + self.assertTrue(request.finish.called) + + if request.setResponseCode.called: + response_code = request.setResponseCode.call_args[0][0] + else: + response_code = 200 + + response_json = "".join( + call[0][0] for call in request.write.call_args_list + ) + response_body = json.loads(response_json) + + defer.returnValue((response_code, response_body)) diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/rest/client/v1/__init__.py +++ b/tests/rest/client/v1/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py deleted file mode 100644 index 8d7cfd79ab..0000000000 --- a/tests/rest/client/v1/test_presence.py +++ /dev/null @@ -1,412 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /presence paths.""" -from tests import unittest -from twisted.internet import defer - -from mock import Mock - -from ....utils import MockHttpResource, setup_test_homeserver - -from synapse.api.constants import PresenceState -from synapse.handlers.presence import PresenceHandler -from synapse.rest.client.v1 import presence -from synapse.rest.client.v1 import events -from synapse.types import Requester, UserID -from synapse.util.async import run_on_reactor - -from collections import namedtuple - - -OFFLINE = PresenceState.OFFLINE -UNAVAILABLE = PresenceState.UNAVAILABLE -ONLINE = PresenceState.ONLINE - - -myid = "@apple:test" -PATH_PREFIX = "/_matrix/client/api/v1" - - -class NullSource(object): - """This event source never yields any events and its token remains at - zero. It may be useful for unit-testing.""" - def __init__(self, hs): - pass - - def get_new_events( - self, - user, - from_key, - room_ids=None, - limit=None, - is_guest=None - ): - return defer.succeed(([], from_key)) - - def get_current_key(self, direction='f'): - return defer.succeed(0) - - def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_key)) - - -class JustPresenceHandlers(object): - def __init__(self, hs): - self.presence_handler = PresenceHandler(hs) - - -class PresenceStateTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - hs = yield setup_test_homeserver( - datastore=Mock(spec=[ - "get_presence_state", - "set_presence_state", - "insert_client_ip", - ]), - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, - ) - hs.handlers = JustPresenceHandlers(hs) - - self.datastore = hs.get_datastore() - self.datastore.get_app_service_by_token = Mock(return_value=None) - - def get_presence_list(*a, **kw): - return defer.succeed([]) - self.datastore.get_presence_list = get_presence_list - - def _get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(myid), - "token_id": 1, - "is_guest": False, - } - - hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token - - room_member_handler = hs.handlers.room_member_handler = Mock( - spec=[ - "get_joined_rooms_for_user", - ] - ) - - def get_rooms_for_user(user): - return defer.succeed([]) - room_member_handler.get_joined_rooms_for_user = get_rooms_for_user - - presence.register_servlets(hs, self.mock_resource) - - self.u_apple = UserID.from_string(myid) - - @defer.inlineCallbacks - def test_get_my_status(self): - mocked_get = self.datastore.get_presence_state - mocked_get.return_value = defer.succeed( - {"state": ONLINE, "status_msg": "Available"} - ) - - (code, response) = yield self.mock_resource.trigger("GET", - "/presence/%s/status" % (myid), None) - - self.assertEquals(200, code) - self.assertEquals( - {"presence": ONLINE, "status_msg": "Available"}, - response - ) - mocked_get.assert_called_with("apple") - - @defer.inlineCallbacks - def test_set_my_status(self): - mocked_set = self.datastore.set_presence_state - mocked_set.return_value = defer.succeed({"state": OFFLINE}) - - (code, response) = yield self.mock_resource.trigger("PUT", - "/presence/%s/status" % (myid), - '{"presence": "unavailable", "status_msg": "Away"}') - - self.assertEquals(200, code) - mocked_set.assert_called_with("apple", - {"state": UNAVAILABLE, "status_msg": "Away"} - ) - - -class PresenceListTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - datastore=Mock(spec=[ - "has_presence_state", - "get_presence_state", - "allow_presence_visible", - "is_presence_visible", - "add_presence_list_pending", - "set_presence_list_accepted", - "del_presence_list", - "get_presence_list", - "insert_client_ip", - ]), - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, - ) - hs.handlers = JustPresenceHandlers(hs) - - self.datastore = hs.get_datastore() - self.datastore.get_app_service_by_token = Mock(return_value=None) - - def has_presence_state(user_localpart): - return defer.succeed( - user_localpart in ("apple", "banana",) - ) - self.datastore.has_presence_state = has_presence_state - - def _get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(myid), - "token_id": 1, - "is_guest": False, - } - - hs.handlers.room_member_handler = Mock( - spec=[ - "get_joined_rooms_for_user", - ] - ) - - hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token - - presence.register_servlets(hs, self.mock_resource) - - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - - @defer.inlineCallbacks - def test_get_my_list(self): - self.datastore.get_presence_list.return_value = defer.succeed( - [{"observed_user_id": "@banana:test", "accepted": True}], - ) - - (code, response) = yield self.mock_resource.trigger("GET", - "/presence/list/%s" % (myid), None) - - self.assertEquals(200, code) - self.assertEquals([ - {"user_id": "@banana:test", "presence": OFFLINE, "accepted": True}, - ], response) - - self.datastore.get_presence_list.assert_called_with( - "apple", accepted=True - ) - - @defer.inlineCallbacks - def test_invite(self): - self.datastore.add_presence_list_pending.return_value = ( - defer.succeed(()) - ) - self.datastore.is_presence_visible.return_value = defer.succeed( - True - ) - - (code, response) = yield self.mock_resource.trigger("POST", - "/presence/list/%s" % (myid), - """{"invite": ["@banana:test"]}""" - ) - - self.assertEquals(200, code) - - self.datastore.add_presence_list_pending.assert_called_with( - "apple", "@banana:test" - ) - self.datastore.set_presence_list_accepted.assert_called_with( - "apple", "@banana:test" - ) - - @defer.inlineCallbacks - def test_drop(self): - self.datastore.del_presence_list.return_value = ( - defer.succeed(()) - ) - - (code, response) = yield self.mock_resource.trigger("POST", - "/presence/list/%s" % (myid), - """{"drop": ["@banana:test"]}""" - ) - - self.assertEquals(200, code) - - self.datastore.del_presence_list.assert_called_with( - "apple", "@banana:test" - ) - - -class PresenceEventStreamTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - # HIDEOUS HACKERY - # TODO(paul): This should be injected in via the HomeServer DI system - from synapse.streams.events import ( - PresenceEventSource, EventSources - ) - - old_SOURCE_TYPES = EventSources.SOURCE_TYPES - def tearDown(): - EventSources.SOURCE_TYPES = old_SOURCE_TYPES - self.tearDown = tearDown - - EventSources.SOURCE_TYPES = { - k: NullSource for k in old_SOURCE_TYPES.keys() - } - EventSources.SOURCE_TYPES["presence"] = PresenceEventSource - - clock = Mock(spec=[ - "call_later", - "cancel_call_later", - "time_msec", - "looping_call", - ]) - - clock.time_msec.return_value = 1000000 - - hs = yield setup_test_homeserver( - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, - datastore=Mock(spec=[ - "set_presence_state", - "get_presence_list", - "get_rooms_for_user", - ]), - clock=clock, - ) - - def _get_user_by_req(req=None, allow_guest=False): - return Requester(UserID.from_string(myid), "", False) - - hs.get_v1auth().get_user_by_req = _get_user_by_req - - presence.register_servlets(hs, self.mock_resource) - events.register_servlets(hs, self.mock_resource) - - hs.handlers.room_member_handler = Mock(spec=[]) - - self.room_members = [] - - def get_rooms_for_user(user): - if user in self.room_members: - return ["a-room"] - else: - return [] - hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user - hs.handlers.room_member_handler.get_room_members = ( - lambda r: self.room_members if r == "a-room" else [] - ) - hs.handlers.room_member_handler._filter_events_for_client = ( - lambda user_id, events, **kwargs: events - ) - - self.mock_datastore = hs.get_datastore() - self.mock_datastore.get_app_service_by_token = Mock(return_value=None) - self.mock_datastore.get_app_service_by_user_id = Mock( - return_value=defer.succeed(None) - ) - self.mock_datastore.get_rooms_for_user = ( - lambda u: [ - namedtuple("Room", "room_id")(r) - for r in get_rooms_for_user(UserID.from_string(u)) - ] - ) - - def get_profile_displayname(user_id): - return defer.succeed("Frank") - self.mock_datastore.get_profile_displayname = get_profile_displayname - - def get_profile_avatar_url(user_id): - return defer.succeed(None) - self.mock_datastore.get_profile_avatar_url = get_profile_avatar_url - - def user_rooms_intersect(user_list): - room_member_ids = map(lambda u: u.to_string(), self.room_members) - - shared = all(map(lambda i: i in room_member_ids, user_list)) - return defer.succeed(shared) - self.mock_datastore.user_rooms_intersect = user_rooms_intersect - - def get_joined_hosts_for_room(room_id): - return [] - self.mock_datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - - self.presence = hs.get_handlers().presence_handler - - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - - @defer.inlineCallbacks - def test_shortpoll(self): - self.room_members = [self.u_apple, self.u_banana] - - self.mock_datastore.set_presence_state.return_value = defer.succeed( - {"state": ONLINE} - ) - self.mock_datastore.get_presence_list.return_value = defer.succeed( - [] - ) - - (code, response) = yield self.mock_resource.trigger("GET", - "/events?timeout=0", None) - - self.assertEquals(200, code) - - # We've forced there to be only one data stream so the tokens will - # all be ours - - # I'll already get my own presence state change - self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []}, - response - ) - - self.mock_datastore.set_presence_state.return_value = defer.succeed( - {"state": ONLINE} - ) - self.mock_datastore.get_presence_list.return_value = defer.succeed([]) - - yield self.presence.set_state(self.u_banana, self.u_banana, - state={"presence": ONLINE} - ) - - yield run_on_reactor() - - (code, response) = yield self.mock_resource.trigger("GET", - "/events?from=s0_1_0&timeout=0", None) - - self.assertEquals(200, code) - self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [ - {"type": "m.presence", - "content": { - "user_id": "@banana:test", - "presence": ONLINE, - "displayname": "Frank", - "last_active_ago": 0, - }}, - ]}, response) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index c1a3f52043..0785965de2 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -65,8 +65,9 @@ class ProfileTestCase(unittest.TestCase): mocked_get = self.mock_handler.get_displayname mocked_get.return_value = defer.succeed("Frank") - (code, response) = yield self.mock_resource.trigger("GET", - "/profile/%s/displayname" % (myid), None) + (code, response) = yield self.mock_resource.trigger( + "GET", "/profile/%s/displayname" % (myid), None + ) self.assertEquals(200, code) self.assertEquals({"displayname": "Frank"}, response) @@ -77,9 +78,11 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_displayname mocked_set.return_value = defer.succeed(()) - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/displayname" % (myid), - '{"displayname": "Frank Jr."}') + (code, response) = yield self.mock_resource.trigger( + "PUT", + "/profile/%s/displayname" % (myid), + '{"displayname": "Frank Jr."}' + ) self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") @@ -91,19 +94,23 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_displayname mocked_set.side_effect = AuthError(400, "message") - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/displayname" % ("@4567:test"), '"Frank Jr."') + (code, response) = yield self.mock_resource.trigger( + "PUT", "/profile/%s/displayname" % ("@4567:test"), '"Frank Jr."' + ) - self.assertTrue(400 <= code < 499, - msg="code %d is in the 4xx range" % (code)) + self.assertTrue( + 400 <= code < 499, + msg="code %d is in the 4xx range" % (code) + ) @defer.inlineCallbacks def test_get_other_name(self): mocked_get = self.mock_handler.get_displayname mocked_get.return_value = defer.succeed("Bob") - (code, response) = yield self.mock_resource.trigger("GET", - "/profile/%s/displayname" % ("@opaque:elsewhere"), None) + (code, response) = yield self.mock_resource.trigger( + "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None + ) self.assertEquals(200, code) self.assertEquals({"displayname": "Bob"}, response) @@ -113,19 +120,23 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_displayname mocked_set.side_effect = SynapseError(400, "message") - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/displayname" % ("@opaque:elsewhere"), None) + (code, response) = yield self.mock_resource.trigger( + "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), None + ) - self.assertTrue(400 <= code <= 499, - msg="code %d is in the 4xx range" % (code)) + self.assertTrue( + 400 <= code <= 499, + msg="code %d is in the 4xx range" % (code) + ) @defer.inlineCallbacks def test_get_my_avatar(self): mocked_get = self.mock_handler.get_avatar_url mocked_get.return_value = defer.succeed("http://my.server/me.png") - (code, response) = yield self.mock_resource.trigger("GET", - "/profile/%s/avatar_url" % (myid), None) + (code, response) = yield self.mock_resource.trigger( + "GET", "/profile/%s/avatar_url" % (myid), None + ) self.assertEquals(200, code) self.assertEquals({"avatar_url": "http://my.server/me.png"}, response) @@ -136,12 +147,13 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_avatar_url mocked_set.return_value = defer.succeed(()) - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/avatar_url" % (myid), - '{"avatar_url": "http://my.server/pic.gif"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", + "/profile/%s/avatar_url" % (myid), + '{"avatar_url": "http://my.server/pic.gif"}' + ) self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], - "http://my.server/pic.gif") + self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ad5dd3bd6e..afca5303ba 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -82,19 +82,22 @@ class RoomPermissionsTestCase(RestTestCase): is_public=True) # send a message in one of the rooms - self.created_rmid_msg_path = ("/rooms/%s/send/m.room.message/a1" % - (self.created_rmid)) + self.created_rmid_msg_path = ( + "/rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ) (code, response) = yield self.mock_resource.trigger( - "PUT", - self.created_rmid_msg_path, - '{"msgtype":"m.text","body":"test msg"}') + "PUT", + self.created_rmid_msg_path, + '{"msgtype":"m.text","body":"test msg"}' + ) self.assertEquals(200, code, msg=str(response)) # set topic for public room (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - '{"topic":"Public Room Topic"}') + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + '{"topic":"Public Room Topic"}' + ) self.assertEquals(200, code, msg=str(response)) # auth as user_id now @@ -103,37 +106,6 @@ class RoomPermissionsTestCase(RestTestCase): def tearDown(self): pass -# @defer.inlineCallbacks -# def test_get_message(self): -# # get message in uncreated room, expect 403 -# (code, response) = yield self.mock_resource.trigger_get( -# "/rooms/noroom/messages/someid/m1") -# self.assertEquals(403, code, msg=str(response)) -# -# # get message in created room not joined (no state), expect 403 -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(403, code, msg=str(response)) -# -# # get message in created room and invited, expect 403 -# yield self.invite(room=self.created_rmid, src=self.rmcreator_id, -# targ=self.user_id) -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(403, code, msg=str(response)) -# -# # get message in created room and joined, expect 200 -# yield self.join(room=self.created_rmid, user=self.user_id) -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(200, code, msg=str(response)) -# -# # get message in created room and left, expect 403 -# yield self.leave(room=self.created_rmid, user=self.user_id) -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(403, code, msg=str(response)) - @defer.inlineCallbacks def test_send_message(self): msg_content = '{"msgtype":"m.text","body":"hello"}' @@ -195,25 +167,30 @@ class RoomPermissionsTestCase(RestTestCase): # set/get topic in uncreated room, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, - topic_content) + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, + topic_content + ) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.uncreated_rmid) + "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + ) self.assertEquals(403, code, msg=str(response)) # set/get topic in created PRIVATE room not joined, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) self.assertEquals(403, code, msg=str(response)) # set topic in created PRIVATE room and invited, expect 403 - yield self.invite(room=self.created_rmid, src=self.rmcreator_id, - targ=self.user_id) + yield self.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(403, code, msg=str(response)) # get topic in created PRIVATE room and invited, expect 403 @@ -226,7 +203,8 @@ class RoomPermissionsTestCase(RestTestCase): # Only room ops can set topic by default self.auth_user_id = self.rmcreator_id (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(200, code, msg=str(response)) self.auth_user_id = self.user_id @@ -237,30 +215,31 @@ class RoomPermissionsTestCase(RestTestCase): # set/get topic in created PRIVATE room and left, expect 403 yield self.leave(room=self.created_rmid, user=self.user_id) (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) self.assertEquals(200, code, msg=str(response)) # get topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.created_public_rmid) + "/rooms/%s/state/m.room.topic" % self.created_public_rmid + ) self.assertEquals(403, code, msg=str(response)) # set topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content) + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks def _test_get_membership(self, room=None, members=[], expect_code=None): - path = "/rooms/%s/state/m.room.member/%s" for member in members: - (code, response) = yield self.mock_resource.trigger_get( - path % - (room, member)) + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + (code, response) = yield self.mock_resource.trigger_get(path) self.assertEquals(expect_code, code) @defer.inlineCallbacks @@ -461,20 +440,23 @@ class RoomsMemberListTestCase(RestTestCase): def test_get_member_list(self): room_id = yield self.create_room_as(self.user_id) (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id) + "/rooms/%s/members" % room_id + ) self.assertEquals(200, code, msg=str(response)) @defer.inlineCallbacks def test_get_member_list_no_room(self): (code, response) = yield self.mock_resource.trigger_get( - "/rooms/roomdoesnotexist/members") + "/rooms/roomdoesnotexist/members" + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks def test_get_member_list_no_permission(self): room_id = yield self.create_room_as("@some_other_guy:red") (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id) + "/rooms/%s/members" % room_id + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks @@ -636,34 +618,41 @@ class RoomTopicTestCase(RestTestCase): @defer.inlineCallbacks def test_invalid_puts(self): # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '{}') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '{}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '{"_name":"bob"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '{"_name":"bob"}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '{"nao') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '{"nao' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '[{"_name":"bob"},{"_name":"jill"}]') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, 'text only') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, 'text only' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '' + ) self.assertEquals(400, code, msg=str(response)) # valid key, wrong type content = '{"topic":["Topic name"]}' - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, content) + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, content + ) self.assertEquals(400, code, msg=str(response)) @defer.inlineCallbacks @@ -674,8 +663,9 @@ class RoomTopicTestCase(RestTestCase): # valid put content = '{"topic":"Topic name"}' - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, content) + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, content + ) self.assertEquals(200, code, msg=str(response)) # valid get @@ -687,8 +677,9 @@ class RoomTopicTestCase(RestTestCase): def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, content) + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, content + ) self.assertEquals(200, code, msg=str(response)) # valid get @@ -740,33 +731,38 @@ class RoomMemberStateTestCase(RestTestCase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{}') + (code, response) = yield self.mock_resource.trigger("PUT", path, '{}') self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"_name":"bob"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"_name":"bob"}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"nao') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"nao' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '[{"_name":"bob"},{"_name":"jill"}]') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, 'text only') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, 'text only' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '' + ) self.assertEquals(400, code, msg=str(response)) # valid keys, wrong types - content = ('{"membership":["%s","%s","%s"]}' % - (Membership.INVITE, Membership.JOIN, Membership.LEAVE)) + content = ('{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, Membership.JOIN, Membership.LEAVE + )) (code, response) = yield self.mock_resource.trigger("PUT", path, content) self.assertEquals(400, code, msg=str(response)) @@ -813,8 +809,9 @@ class RoomMemberStateTestCase(RestTestCase): ) # valid invite message with custom key - content = ('{"membership":"%s","invite_text":"%s"}' % - (Membership.INVITE, "Join us!")) + content = ('{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, "Join us!" + )) (code, response) = yield self.mock_resource.trigger("PUT", path, content) self.assertEquals(200, code, msg=str(response)) @@ -867,28 +864,34 @@ class RoomMessagesTestCase(RestTestCase): path = "/rooms/%s/send/m.room.message/mid1" % ( urllib.quote(self.room_id)) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{}') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"_name":"bob"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"_name":"bob"}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"nao') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"nao' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '[{"_name":"bob"},{"_name":"jill"}]') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, 'text only') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, 'text only' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '' + ) self.assertEquals(400, code, msg=str(response)) @defer.inlineCallbacks @@ -953,19 +956,14 @@ class RoomInitialSyncTestCase(RestTestCase): synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - # Since I'm getting my own presence I need to exist as far as presence - # is concerned. - hs.get_handlers().presence_handler.registered_user( - UserID.from_string(self.user_id) - ) - # create the room self.room_id = yield self.create_room_as(self.user_id) @defer.inlineCallbacks def test_initial_sync(self): (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/initialSync" % self.room_id) + "/rooms/%s/initialSync" % self.room_id + ) self.assertEquals(200, code) self.assertEquals(self.room_id, response["room_id"]) @@ -989,8 +987,8 @@ class RoomInitialSyncTestCase(RestTestCase): self.assertTrue("presence" in response) - presence_by_user = {e["content"]["user_id"]: e - for e in response["presence"] + presence_by_user = { + e["content"]["user_id"]: e for e in response["presence"] } self.assertTrue(self.user_id in presence_by_user) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index c4ac181a33..16d788ff61 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -81,9 +81,9 @@ class RoomTypingTestCase(RestTestCase): return defer.succeed([]) @defer.inlineCallbacks - def fetch_room_distributions_into(room_id, localusers=None, - remotedomains=None, ignore_user=None): - + def fetch_room_distributions_into( + room_id, localusers=None, remotedomains=None, ignore_user=None + ): members = yield get_room_members(room_id) for member in members: if ignore_user is not None and member == ignore_user: @@ -96,7 +96,8 @@ class RoomTypingTestCase(RestTestCase): if remotedomains is not None: remotedomains.add(member.domain) hs.get_handlers().room_member_handler.fetch_room_distributions_into = ( - fetch_room_distributions_into) + fetch_room_distributions_into + ) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) @@ -109,8 +110,8 @@ class RoomTypingTestCase(RestTestCase): @defer.inlineCallbacks def test_set_typing(self): - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": true, "timeout": 30000}' ) self.assertEquals(200, code) @@ -120,41 +121,38 @@ class RoomTypingTestCase(RestTestCase): from_key=0, room_ids=[self.room_id], ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.user_id], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.user_id], + } + }]) @defer.inlineCallbacks def test_set_not_typing(self): - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": false}' ) self.assertEquals(200, code) @defer.inlineCallbacks def test_typing_timeout(self): - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": true, "timeout": 30000}' ) self.assertEquals(200, code) self.assertEquals(self.event_source.get_current_key(), 1) - self.clock.advance_time(31); + self.clock.advance_time(31) self.assertEquals(self.event_source.get_current_key(), 2) - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": true, "timeout": 30000}' ) self.assertEquals(200, code) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index af376804f6..17524b2e23 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -84,8 +84,9 @@ class RestTestCase(unittest.TestCase): "membership": membership } - (code, response) = yield self.mock_resource.trigger("PUT", path, - json.dumps(data)) + (code, response) = yield self.mock_resource.trigger( + "PUT", path, json.dumps(data) + ) self.assertEquals(expect_code, code, msg=str(response)) self.auth_user_id = temp_id diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 16dce6c723..84334dce34 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -55,7 +55,7 @@ class V2AlphaRestTestCase(unittest.TestCase): r.register_servlets(hs, self.mock_resource) def make_datastore_mock(self): - store = Mock(spec=[ + store = Mock(spec=[ "insert_client_ip", ]) store.get_app_service_by_token = Mock(return_value=None) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index c86e904c8e..d1442aafac 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from mock import Mock - from . import V2AlphaRestTestCase from synapse.rest.client.v2_alpha import filter @@ -53,9 +51,8 @@ class FilterTestCase(V2AlphaRestTestCase): @defer.inlineCallbacks def test_add_filter(self): - (code, response) = yield self.mock_resource.trigger("POST", - "/user/%s/filter" % (self.USER_ID), - '{"type": ["m.*"]}' + (code, response) = yield self.mock_resource.trigger( + "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}' ) self.assertEquals(200, code) self.assertEquals({"filter_id": "0"}, response) @@ -70,8 +67,8 @@ class FilterTestCase(V2AlphaRestTestCase): {"type": ["m.*"]} ] - (code, response) = yield self.mock_resource.trigger("GET", - "/user/%s/filter/0" % (self.USER_ID), None + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/0" % (self.USER_ID) ) self.assertEquals(200, code) self.assertEquals({"type": ["m.*"]}, response) @@ -82,14 +79,14 @@ class FilterTestCase(V2AlphaRestTestCase): {"type": ["m.*"]} ] - (code, response) = yield self.mock_resource.trigger("GET", - "/user/%s/filter/2" % (self.USER_ID), None + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/2" % (self.USER_ID) ) self.assertEquals(404, code) @defer.inlineCallbacks def test_get_filter_no_user(self): - (code, response) = yield self.mock_resource.trigger("GET", - "/user/%s/filter/0" % (self.USER_ID), None + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/0" % (self.USER_ID) ) self.assertEquals(404, code) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index df0841b0b1..b867599079 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,7 +1,7 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.api.errors import SynapseError from twisted.internet import defer -from mock import Mock, MagicMock +from mock import Mock from tests import unittest import json @@ -24,7 +24,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.auth_result = (False, None, None) self.auth_handler = Mock( - check_auth=Mock(side_effect=lambda x,y,z: self.auth_result) + check_auth=Mock(side_effect=lambda x, y, z: self.auth_result) ) self.registration_handler = Mock() self.identity_handler = Mock() diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index dca785eb27..f22ba8db89 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -14,15 +14,9 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID - -from tests.utils import setup_test_homeserver - -from mock import Mock +from synapse.api.constants import EventTypes class EventInjector: diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 219288621d..96b7dba5fe 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -174,11 +174,13 @@ class CacheDecoratorTestCase(unittest.TestCase): # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times - for k in range(0,12): + for k in range(0, 12): yield a.func(k) - self.assertTrue(callcount[0] >= 14, - msg="Expected callcount >= 14, got %d" % (callcount[0])) + self.assertTrue( + callcount[0] >= 14, + msg="Expected callcount >= 14, got %d" % (callcount[0]) + ) def test_prefill(self): callcount = [0] diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ed8af10d87..5734198121 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -35,7 +35,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): def setUp(self): self.as_yaml_files = [] config = Mock( - app_service_config_files=self.as_yaml_files + app_service_config_files=self.as_yaml_files, + event_cache_size=1, ) hs = yield setup_test_homeserver(config=config) @@ -109,7 +110,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] config = Mock( - app_service_config_files=self.as_yaml_files + app_service_config_files=self.as_yaml_files, + event_cache_size=1, ) hs = yield setup_test_homeserver(config=config) self.db_pool = hs.get_db_pool() @@ -438,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - config = Mock(app_service_config_files=[f1, f2]) + config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) hs = yield setup_test_homeserver(config=config, datastore=Mock()) ApplicationServiceStore(hs) @@ -448,7 +450,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - config = Mock(app_service_config_files=[f1, f2]) + config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) hs = yield setup_test_homeserver(config=config, datastore=Mock()) with self.assertRaises(ConfigError) as cm: @@ -464,7 +466,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - config = Mock(app_service_config_files=[f1, f2]) + config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) hs = yield setup_test_homeserver(config=config, datastore=Mock()) with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 29289fa9b4..6e4d9b1373 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,13 +1,11 @@ from tests import unittest from twisted.internet import defer -from synapse.api.constants import EventTypes -from synapse.types import UserID, RoomID, RoomAlias - from tests.utils import setup_test_homeserver from mock import Mock + class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -24,8 +22,8 @@ class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000; - duration_ms = 42; + desired_count = 1000 + duration_ms = 42 @defer.inlineCallbacks def update(progress, count): diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 152d027663..c76545be65 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -17,7 +17,7 @@ from tests import unittest from twisted.internet import defer -from mock import Mock, call +from mock import Mock from collections import OrderedDict @@ -62,13 +62,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_insert( - table="tablename", - values={"columname": "Value"} + table="tablename", + values={"columname": "Value"} ) self.mock_txn.execute.assert_called_with( - "INSERT INTO tablename (columname) VALUES(?)", - ("Value",) + "INSERT INTO tablename (columname) VALUES(?)", ("Value",) ) @defer.inlineCallbacks @@ -76,14 +75,14 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]) + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]) ) self.mock_txn.execute.assert_called_with( - "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", - (1, 2, 3,) + "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", + (1, 2, 3,) ) @defer.inlineCallbacks @@ -92,15 +91,14 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.fetchall.return_value = [("Value",)] value = yield self.datastore._simple_select_one_onecol( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcol="retcol" + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcol="retcol" ) self.assertEquals("Value", value) self.mock_txn.execute.assert_called_with( - "SELECT retcol FROM tablename WHERE keycol = ?", - ["TheKey"] + "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @defer.inlineCallbacks @@ -109,15 +107,15 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.fetchone.return_value = (1, 2, 3) ret = yield self.datastore._simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"] + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"] ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( - "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", - ["TheKey"] + "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", + ["TheKey"] ) @defer.inlineCallbacks @@ -126,32 +124,32 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.fetchone.return_value = None ret = yield self.datastore._simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True ) self.assertFalse(ret) @defer.inlineCallbacks def test_select_list(self): - self.mock_txn.rowcount = 3; + self.mock_txn.rowcount = 3 self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) self.mock_txn.description = ( - ("colA", None, None, None, None, None, None), + ("colA", None, None, None, None, None, None), ) ret = yield self.datastore._simple_select_list( - table="tablename", - keyvalues={"keycol": "A set"}, - retcols=["colA"], + table="tablename", + keyvalues={"keycol": "A set"}, + retcols=["colA"], ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( - "SELECT colA FROM tablename WHERE keycol = ?", - ["A set"] + "SELECT colA FROM tablename WHERE keycol = ?", + ["A set"] ) @defer.inlineCallbacks @@ -159,14 +157,14 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"} + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"} ) self.mock_txn.execute.assert_called_with( - "UPDATE tablename SET columnname = ? WHERE keycol = ?", - ["New Value", "TheKey"] + "UPDATE tablename SET columnname = ? WHERE keycol = ?", + ["New Value", "TheKey"] ) @defer.inlineCallbacks @@ -174,15 +172,15 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]) + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]) ) self.mock_txn.execute.assert_called_with( - "UPDATE tablename SET colC = ?, colD = ? WHERE " + - "colA = ? AND colB = ?", - [3, 4, 1, 2] + "UPDATE tablename SET colC = ?, colD = ? WHERE" + " colA = ? AND colB = ?", + [3, 4, 1, 2] ) @defer.inlineCallbacks @@ -190,11 +188,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_delete_one( - table="tablename", - keyvalues={"keycol": "Go away"}, + table="tablename", + keyvalues={"keycol": "Go away"}, ) self.mock_txn.execute.assert_called_with( - "DELETE FROM tablename WHERE keycol = ?", - ["Go away"] + "DELETE FROM tablename WHERE keycol = ?", ["Go away"] ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 4aa82d4c9d..18a6cff0c7 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import uuid from mock import Mock from synapse.types import RoomID, UserID diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 333f1e10f1..ec78f007ca 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -35,32 +35,6 @@ class PresenceStoreTestCase(unittest.TestCase): self.u_banana = UserID.from_string("@banana:test") @defer.inlineCallbacks - def test_state(self): - yield self.store.create_presence( - self.u_apple.localpart - ) - - state = yield self.store.get_presence_state( - self.u_apple.localpart - ) - - self.assertEquals( - {"state": None, "status_msg": None, "mtime": None}, state - ) - - yield self.store.set_presence_state( - self.u_apple.localpart, {"state": "online", "status_msg": "Here"} - ) - - state = yield self.store.get_presence_state( - self.u_apple.localpart - ) - - self.assertEquals( - {"state": "online", "status_msg": "Here", "mtime": 1000000}, state - ) - - @defer.inlineCallbacks def test_visibility(self): self.assertFalse((yield self.store.is_presence_visible( observed_localpart=self.u_apple.localpart, diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 47e2768b2c..24118bbc86 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -55,7 +55,7 @@ class ProfileStoreTestCase(unittest.TestCase): ) yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.localpart, "http://my.site/here" ) self.assertEquals( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7b3b4c13bc..b8384c98d8 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -33,8 +33,10 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.user_id = "@my-user:test" - self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", - "BcDeFgHiJkLmNoPqRsTuVwXyZa"] + self.tokens = [ + "AbCdEfGhIjKlMnOpQrStUvWxYz", + "BcDeFgHiJkLmNoPqRsTuVwXyZa" + ] self.pwhash = "{xx1}123456789" @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 0baaf3df21..ef8a4d234f 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,7 +37,8 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room(self.room.to_string(), + yield self.store.store_room( + self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), is_public=True ) @@ -45,9 +46,11 @@ class RoomStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_room(self): self.assertDictContainsSubset( - {"room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "is_public": True}, + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "is_public": True + }, (yield self.store.get_room(self.room.to_string())) ) @@ -65,7 +68,8 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room(self.room.to_string(), + yield self.store.store_room( + self.room.to_string(), room_creator_user_id="@creator:text", is_public=True ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index bab15c4165..677d11f68d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -88,8 +88,8 @@ class RoomMemberStoreTestCase(unittest.TestCase): [m.room_id for m in ( yield self.store.get_rooms_for_user_where_membership_is( self.u_alice.to_string(), [Membership.JOIN] - )) - ] + ) + )] ) self.assertFalse( (yield self.store.user_rooms_intersect( @@ -108,11 +108,11 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.store.get_room_members(self.room.to_string()) )} ) - self.assertTrue( - (yield self.store.user_rooms_intersect( - [self.u_alice.to_string(), self.u_bob.to_string()] - )) - ) + self.assertTrue(( + yield self.store.user_rooms_intersect([ + self.u_alice.to_string(), self.u_bob.to_string() + ]) + )) @defer.inlineCallbacks def test_room_hosts(self): @@ -136,9 +136,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.assertEquals( {"test", "elsewhere"}, - (yield - self.store.get_joined_hosts_for_room(self.room.to_string()) - ) + (yield self.store.get_joined_hosts_for_room(self.room.to_string())) ) # Should still have both hosts @@ -146,9 +144,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.assertEquals( {"test", "elsewhere"}, - (yield - self.store.get_joined_hosts_for_room(self.room.to_string()) - ) + (yield self.store.get_joined_hosts_for_room(self.room.to_string())) ) # Should have only one host after other leaves diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 708208aff1..da322152c7 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -156,13 +156,13 @@ class StreamStoreTestCase(unittest.TestCase): self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.event_injector.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.event_injector.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, ) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index a80f580ba6..acebcf4a86 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -44,8 +44,10 @@ class DistributorTestCase(unittest.TestCase): self.dist.declare("whine") d_inner = defer.Deferred() + def observer(): return d_inner + self.dist.observe("whine", observer) d_outer = self.dist.fire("whine") @@ -66,8 +68,8 @@ class DistributorTestCase(unittest.TestCase): observers[0].side_effect = Exception("Awoogah!") - with patch("synapse.util.distributor.logger", - spec=["warning"] + with patch( + "synapse.util.distributor.logger", spec=["warning"] ) as mock_logger: d = self.dist.fire("alarm", "Go") yield d @@ -77,8 +79,9 @@ class DistributorTestCase(unittest.TestCase): observers[1].assert_called_once_with("Go") self.assertEquals(mock_logger.warning.call_count, 1) - self.assertIsInstance(mock_logger.warning.call_args[0][0], - str) + self.assertIsInstance( + mock_logger.warning.call_args[0][0], str + ) @defer.inlineCallbacks def test_signal_catch_no_suppress(self): diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 3718f4fc3f..d28bb726bb 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -17,6 +17,7 @@ from tests import unittest from tests.utils import MockClock + class MockClockTestCase(unittest.TestCase): def setUp(self): @@ -60,7 +61,7 @@ class MockClockTestCase(unittest.TestCase): def _cb1(): invoked[1] = 1 - t1 = self.clock.call_later(20, _cb1) + self.clock.call_later(20, _cb1) self.clock.cancel_call_later(t0) diff --git a/tests/unittest.py b/tests/unittest.py index 6f02eb4cac..5b22abfe74 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -37,9 +37,12 @@ def around(target): def _around(code): name = code.__name__ orig = getattr(target, name) + def new(*args, **kwargs): return code(orig, *args, **kwargs) + setattr(target, name, new) + return _around @@ -53,9 +56,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", - getattr(self, "loglevel", - NEVER)) + level = getattr(method, "loglevel", getattr(self, "loglevel", NEVER)) @around(self) def setUp(orig): diff --git a/tests/util/__init__.py b/tests/util/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/util/__init__.py +++ b/tests/util/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 7bbe795622..272b71034a 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -14,7 +14,6 @@ # limitations under the License. -from twisted.internet import defer from tests import unittest from synapse.util.caches.dictionary_cache import DictionaryCache diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index 4ee0d49673..7e289715ba 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -19,6 +19,7 @@ from .. import unittest from synapse.util.caches.snapshot_cache import SnapshotCache from twisted.internet.defer import Deferred + class SnapshotCacheTestCase(unittest.TestCase): def setUp(self): diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 1efbeb6b33..7ab578a185 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -18,6 +18,7 @@ from .. import unittest from synapse.util.caches.treecache import TreeCache + class TreeCacheTestCase(unittest.TestCase): def test_get_set_onelevel(self): cache = TreeCache() @@ -76,3 +77,9 @@ class TreeCacheTestCase(unittest.TestCase): cache[("b",)] = "B" cache.clear() self.assertEquals(len(cache), 0) + + def test_contains(self): + cache = TreeCache() + cache[("a",)] = "A" + self.assertTrue(("a",) in cache) + self.assertFalse(("b",) in cache) diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py new file mode 100644 index 0000000000..c44567e52e --- /dev/null +++ b/tests/util/test_wheel_timer.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .. import unittest + +from synapse.util.wheel_timer import WheelTimer + + +class WheelTimerTestCase(unittest.TestCase): + def test_single_insert_fetch(self): + wheel = WheelTimer(bucket_size=5) + + obj = object() + wheel.insert(100, obj, 150) + + self.assertListEqual(wheel.fetch(101), []) + self.assertListEqual(wheel.fetch(110), []) + self.assertListEqual(wheel.fetch(120), []) + self.assertListEqual(wheel.fetch(130), []) + self.assertListEqual(wheel.fetch(149), []) + self.assertListEqual(wheel.fetch(156), [obj]) + self.assertListEqual(wheel.fetch(170), []) + + def test_mutli_insert(self): + wheel = WheelTimer(bucket_size=5) + + obj1 = object() + obj2 = object() + obj3 = object() + wheel.insert(100, obj1, 150) + wheel.insert(105, obj2, 130) + wheel.insert(106, obj3, 160) + + self.assertListEqual(wheel.fetch(110), []) + self.assertListEqual(wheel.fetch(135), [obj2]) + self.assertListEqual(wheel.fetch(149), []) + self.assertListEqual(wheel.fetch(158), [obj1]) + self.assertListEqual(wheel.fetch(160), []) + self.assertListEqual(wheel.fetch(200), [obj3]) + self.assertListEqual(wheel.fetch(210), []) + + def test_insert_past(self): + wheel = WheelTimer(bucket_size=5) + + obj = object() + wheel.insert(100, obj, 50) + self.assertListEqual(wheel.fetch(120), [obj]) + + def test_insert_past_mutli(self): + wheel = WheelTimer(bucket_size=5) + + obj1 = object() + obj2 = object() + obj3 = object() + wheel.insert(100, obj1, 150) + wheel.insert(100, obj2, 140) + wheel.insert(100, obj3, 50) + self.assertListEqual(wheel.fetch(110), [obj3]) + self.assertListEqual(wheel.fetch(120), []) + self.assertListEqual(wheel.fetch(147), [obj2]) + self.assertListEqual(wheel.fetch(200), [obj1]) + self.assertListEqual(wheel.fetch(240), []) diff --git a/tests/utils.py b/tests/utils.py index 3b1eb50d8d..dfbee5c23a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -152,7 +152,7 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" - mock_request.requestHeaders.getRawHeaders.return_value=[ + mock_request.requestHeaders.getRawHeaders.return_value = [ "X-Matrix origin=test,key=,sig=" ] @@ -224,12 +224,12 @@ class MockClock(object): def time_msec(self): return self.time() * 1000 - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() def wrapped_callback(): LoggingContext.thread_local.current_context = current_context - callback() + callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] self.timers.append(t) @@ -239,9 +239,10 @@ class MockClock(object): def looping_call(self, function, interval): pass - def cancel_call_later(self, timer): + def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: - raise Exception("Cannot cancel an expired timer") + if not ignore_errs: + raise Exception("Cannot cancel an expired timer") timer[2] = True self.timers = [t for t in self.timers if t != timer] @@ -360,13 +361,12 @@ class MemoryDataStore(object): def get_rooms_for_user_where_membership_is(self, user_id, membership_list): return [ - self.members[r].get(user_id) for r in self.members - if user_id in self.members[r] and - self.members[r][user_id].membership in membership_list + m[user_id] for m in self.members.values() + if user_id in m and m[user_id].membership in membership_list ] def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - limit=0, with_feedback=False): + limit=0, with_feedback=False): return ([], from_key) # TODO def get_joined_hosts_for_room(self, room_id): @@ -376,7 +376,6 @@ class MemoryDataStore(object): if event.type == EventTypes.Member: room_id = event.room_id user = event.state_key - membership = event.membership self.members.setdefault(room_id, {})[user] = event if hasattr(event, "state_key"): @@ -456,9 +455,9 @@ class DeferredMockCallable(object): d.callback(None) return result - failure = AssertionError("Was not expecting call(%s)" % + failure = AssertionError("Was not expecting call(%s)" % ( _format_call(args, kwargs) - ) + )) for _, _, d in self.expectations: try: @@ -479,14 +478,12 @@ class DeferredMockCallable(object): ) timer = reactor.callLater( - timeout/1000, + timeout / 1000, deferred.errback, - AssertionError( - "%d pending calls left: %s"% ( - len([e for e in self.expectations if not e[2].called]), - [e for e in self.expectations if not e[2].called] - ) - ) + AssertionError("%d pending calls left: %s" % ( + len([e for e in self.expectations if not e[2].called]), + [e for e in self.expectations if not e[2].called] + )) ) yield deferred @@ -500,8 +497,8 @@ class DeferredMockCallable(object): calls = self.calls self.calls = [] - raise AssertionError("Expected not to received any calls, got:\n" + - "\n".join([ + raise AssertionError( + "Expected not to received any calls, got:\n" + "\n".join([ "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) diff --git a/tox.ini b/tox.ini index 6f01599af7..757b7189c3 100644 --- a/tox.ini +++ b/tox.ini @@ -26,4 +26,4 @@ skip_install = True basepython = python2.7 deps = flake8 -commands = /bin/bash -c "flake8 synapse {env:PEP8SUFFIX:}" +commands = /bin/bash -c "flake8 synapse tests {env:PEP8SUFFIX:}" |