From 763360594dfb90433f693056d3d64ac82409fc87 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 11 Feb 2016 14:10:00 +0000 Subject: Mark AS users with their AS's ID --- tests/utils.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'tests/utils.py') diff --git a/tests/utils.py b/tests/utils.py index 3b1eb50d8d..f3935648a0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,6 +51,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.server_name = "server.under.test" config.trusted_third_party_id_servers = [] + config.database_config = {"name": "sqlite3"} + if "clock" not in kargs: kargs["clock"] = MockClock() @@ -60,7 +62,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=db_pool, config=config, version_string="Synapse/tests", - database_engine=create_engine("sqlite3"), + database_engine=create_engine(config), get_db_conn=db_pool.get_db_conn, **kargs ) @@ -69,7 +71,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine("sqlite3"), + database_engine=create_engine(config), **kargs ) @@ -277,18 +279,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object): cp_max=1, ) + self.config = Mock() + self.config.database_config = {"name": "sqlite3"} + def prepare(self): - engine = create_engine("sqlite3") + engine = self.create_engine() return self.runWithConnection( - lambda conn: prepare_database(conn, engine) + lambda conn: prepare_database(conn, engine, self.config) ) def get_db_conn(self): conn = self.connect() - engine = create_engine("sqlite3") - prepare_database(conn, engine) + engine = self.create_engine() + prepare_database(conn, engine, self.config) return conn + def create_engine(self): + return create_engine(self.config) + class MemoryDataStore(object): -- cgit 1.4.1 From e5999bfb1a4aab56acecb59ed6d068442f5b11a0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Feb 2016 17:10:40 +0000 Subject: Initial cut --- synapse/handlers/events.py | 43 +- synapse/handlers/message.py | 14 +- synapse/handlers/presence.py | 1662 ++++++++------------ synapse/handlers/profile.py | 3 + synapse/handlers/sync.py | 22 + synapse/rest/client/v1/presence.py | 26 +- synapse/rest/client/v1/room.py | 18 +- synapse/rest/client/v2_alpha/receipts.py | 3 + synapse/rest/client/v2_alpha/sync.py | 16 +- synapse/storage/__init__.py | 50 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/presence.py | 170 +- .../storage/schema/delta/30/presence_stream.sql | 30 + synapse/storage/util/id_generators.py | 4 +- synapse/util/__init__.py | 2 +- tests/utils.py | 4 +- 16 files changed, 933 insertions(+), 1136 deletions(-) create mode 100644 synapse/storage/schema/delta/30/presence_stream.sql (limited to 'tests/utils.py') diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4933c31c19..72a31a9755 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,8 @@ from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event from synapse.util.logcontext import preserve_context_over_fn +from synapse.api.constants import Membership, EventTypes +from synapse.events import EventBase from ._base import BaseHandler @@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not 
None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) + presence_handler = self.hs.get_handlers().presence_handler - try: - if affect_presence: - yield self.started_stream(auth_user) - + context = yield presence_handler.user_syncing( + auth_user_id, affect_presence=affect_presence, + ) + with context: if timeout: # If they've set a timeout set a minimum limit. timeout = max(timeout, 500) @@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler): is_guest=is_guest, explicit_room_id=room_id ) + # When the user joins a new room, or another user joins a currently + # joined room, we need to send down presence for those users. + to_add = [] + for event in events: + if not isinstance(event, EventBase): + continue + if event.type == EventTypes.Member: + if event.membership != Membership.JOIN: + continue + # Send down presence. + if event.state_key == auth_user_id: + # Send down presence for everyone in the room. + users = yield self.store.get_users_in_room(event.room_id) + states = yield presence_handler.get_states( + users, + as_event=True, + ) + to_add.extend(states) + else: + + ev = yield presence_handler.get_state( + UserID.from_string(event.state_key), + as_event=True, + ) + to_add.append(ev) + + events.extend(to_add) + time_now = self.clock.time_msec() chunks = [ @@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) - finally: - if affect_presence: - self.stopped_stream(auth_user) - class EventHandler(BaseHandler): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 82c8cb5f0c..77894d9132 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,7 +21,6 @@ from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -254,8 +253,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Message: presence = self.hs.get_handlers().presence_handler - with PreserveLoggingContext(): - presence.bump_presence_active_time(user) + yield presence.bump_presence_active_time(user) @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, @@ -660,10 +658,6 @@ class MessageHandler(BaseHandler): room_id=room_id, ) - # TODO(paul): I wish I was called with user objects not user_id - # strings... - auth_user = UserID.from_string(user_id) - # TODO: These concurrently time_now = self.clock.time_msec() state = [ @@ -688,13 +682,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): states = yield presence_handler.get_states( - target_users=[UserID.from_string(m.user_id) for m in room_members], - auth_user=auth_user, + [m.user_id for m in room_members], as_event=True, - check_auth=False, ) - defer.returnValue(states.values()) + defer.returnValue(states) @defer.inlineCallbacks def get_receipts(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b61394f2b5..26f2e669ce 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -13,13 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +"""This module is responsible for keeping track of presence status of local +and remote users. 
-from synapse.api.errors import SynapseError, AuthError +The methods that define policy are: + - PresenceHandler._update_states + - PresenceHandler._handle_timeouts + - should_notify +""" + +from twisted.internet import defer, reactor +from contextlib import contextmanager + +from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState +from synapse.storage.presence import UserPresenceState -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID import synapse.metrics @@ -33,33 +45,24 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -# Don't bother bumping "last active" time if it differs by less than 60 seconds +# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them +# "currently_active" LAST_ACTIVE_GRANULARITY = 60 * 1000 -# Keep no more than this number of offline serial revisions -MAX_OFFLINE_SERIALS = 1000 - - -# TODO(paul): Maybe there's one of these I can steal from somewhere -def partition(l, func): - """Partition the list by the result of func applied to each element.""" - ret = {} +# How long to wait until a new /events or /sync request before assuming +# the client has gone. +SYNC_ONLINE_TIMEOUT = 30 * 1000 - for x in l: - key = func(x) - if key not in ret: - ret[key] = [] - ret[key].append(x) +# How long to wait before marking the user as idle. Compared against last active +IDLE_TIMER = 5 * 60 * 1000 - return ret +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 30 * 60 * 1000 +# How often to resend presence to remote servers +FEDERATION_PING_INTERVAL = 25 * 60 * 1000 -def partitionbool(l, func): - def boolfunc(x): - return bool(func(x)) - - ret = partition(l, boolfunc) - return ret.get(True, []), ret.get(False, []) +assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER def user_presence_changed(distributor, user, statuscache): @@ -72,45 +75,13 @@ def collect_presencelike_data(distributor, user, content): class PresenceHandler(BaseHandler): - STATE_LEVELS = { - PresenceState.OFFLINE: 0, - PresenceState.UNAVAILABLE: 1, - PresenceState.ONLINE: 2, - PresenceState.FREE_FOR_CHAT: 3, - } - def __init__(self, hs): super(PresenceHandler, self).__init__(hs) - - self.homeserver = hs - + self.hs = hs self.clock = hs.get_clock() - - distributor = hs.get_distributor() - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "started_user_eventstream", self.started_user_eventstream - ) - distributor.observe( - "stopped_user_eventstream", self.stopped_user_eventstream - ) - - distributor.observe("user_joined_room", self.user_joined_room) - - distributor.declare("collect_presencelike_data") - - distributor.declare("changed_presencelike_data") - distributor.observe( - "changed_presencelike_data", self.changed_presencelike_data - ) - - # outbound signal from the presence module to advertise when a user's - # presence has changed - distributor.declare("user_presence_changed") - - self.distributor = distributor - + self.store = hs.get_datastore() + self.wheel_timer = WheelTimer() + self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() self.federation.register_edu_handler( @@ -138,348 +109,574 @@ class PresenceHandler(BaseHandler): ) ) - # IN-MEMORY store, mapping local userparts to sets of local users to - # be informed of state changes. 
- self._local_pushmap = {} - # map local users to sets of remote /domain names/ who are interested - # in them - self._remote_sendmap = {} - # map remote users to sets of local users who're interested in them - self._remote_recvmap = {} - # list of (serial, set of(userids)) tuples, ordered by serial, latest - # first - self._remote_offline_serials = [] - - # map any user to a UserPresenceCache - self._user_cachemap = {} - self._user_cachemap_latest_serial = 0 - - # map room_ids to the latest presence serial for a member of that - # room - self._room_serials = {} - - metrics.register_callback( - "userCachemap:size", - lambda: len(self._user_cachemap), + distributor = hs.get_distributor() + distributor.observe("user_joined_room", self.user_joined_room) + + active_presence = self.store.take_presence_startup_info() + + # A dictionary of the current state of users. This is prefilled with + # non-offline presence from the DB. We should fetch from the DB if + # we can't find a users presence in here. + self.user_to_current_state = { + state.user_id: state + for state in active_presence + } + + now = self.clock.time_msec() + for state in active_presence: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_active + IDLE_TIMER, + ) + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_user_sync + SYNC_ONLINE_TIMEOUT, + ) + if self.hs.is_mine_id(state.user_id): + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update + FEDERATION_PING_INTERVAL, + ) + else: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update + FEDERATION_TIMEOUT, + ) + + # Set of users who have presence in the `user_to_current_state` that + # have not yet been persisted + self.unpersisted_users_changes = set() + + reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + + self.serial_to_user = {} + self._next_serial = 1 + + # Keeps track of the number of *ongoing* syncs. While this is non zero + # a user will never go offline. + self.user_to_num_current_syncs = {} + + # Start a LoopingCall in 30s that fires every 5s. + # The initial delay is to allow disconnected clients a chance to + # reconnect before we treat them as offline. + self.clock.call_later( + 0 * 1000, + self.clock.looping_call, + self._handle_timeouts, + 5000, ) - def _get_or_make_usercache(self, user): - """If the cache entry doesn't exist, initialise a new one.""" - if user not in self._user_cachemap: - self._user_cachemap[user] = UserPresenceCache() - return self._user_cachemap[user] - - def _get_or_offline_usercache(self, user): - """If the cache entry doesn't exist, return an OFFLINE one but do not - store it into the cache.""" - if user in self._user_cachemap: - return self._user_cachemap[user] - else: - return UserPresenceCache() + @defer.inlineCallbacks + def _on_shutdown(self): + """Gets called when shutting down. This lets us persist any updates that + we haven't yet persisted, e.g. updates that only changes some internal + timers. This allows changes to persist across startup without having to + persist every single change. + + If this does not run it simply means that some of the timers will fire + earlier than they should when synapse is restarted. This affect of this + is some spurious presence changes that will self-correct. + """ + logger.info( + "Performing _on_shutdown. 
Persiting %d unpersisted changes", + len(self.user_to_current_state) + ) - def registered_user(self, user): - return self.store.create_presence(user.localpart) + if self.unpersisted_users_changes: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ]) + logger.info("Finished _on_shutdown") @defer.inlineCallbacks - def is_presence_visible(self, observer_user, observed_user): - assert(self.hs.is_mine(observed_user)) + def _update_states(self, new_states): + """Updates presence of users. Sets the appropriate timeouts. Pokes + the notifier and federation if and only if the changed presence state + should be sent to clients/servers. + """ + now = self.clock.time_msec() - if observer_user == observed_user: - defer.returnValue(True) + # NOTE: We purposefully don't yield between now and when we've + # calculated what we want to do with the new states, to avoid races. - if (yield self.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user])): - defer.returnValue(True) + to_notify = {} # Changes we want to notify everyone about + to_federation_ping = {} # These need sending keep-alives + for new_state in new_states: + user_id = new_state.user_id + prev_state = self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) - if (yield self.store.is_presence_visible( - observed_localpart=observed_user.localpart, - observer_userid=observer_user.to_string())): - defer.returnValue(True) + # If the users are ours then we want to set up a bunch of timers + # to time things out. + if self.hs.is_mine_id(user_id): + if new_state.state == PresenceState.ONLINE: + # Idle timer + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active + IDLE_TIMER + ) - defer.returnValue(False) + if new_state.state != PresenceState.OFFLINE: + # User has stopped syncing + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync + SYNC_ONLINE_TIMEOUT + ) - @defer.inlineCallbacks - def get_state(self, target_user, auth_user, as_event=False, check_auth=True): - """Get the current presence state of the given user. + last_federate = new_state.last_federation_update + if now - last_federate > FEDERATION_PING_INTERVAL: + # Been a while since we've poked remote servers + new_state = new_state.copy_and_replace( + last_federation_update=now, + ) + to_federation_ping[user_id] = new_state - Args: - target_user (UserID): The user whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_user` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + else: + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_federation_update + FEDERATION_TIMEOUT + ) - Returns: - dict: The presence state of the `target_user`, whose format depends - on the `as_event` argument. 
- """ - if self.hs.is_mine(target_user): - if check_auth: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=target_user + if new_state.state == PresenceState.ONLINE: + currently_active = now - new_state.last_active < LAST_ACTIVE_GRANULARITY + new_state = new_state.copy_and_replace( + currently_active=currently_active, ) - if not visible: - raise SynapseError(404, "Presence information not visible") + # Check whether the change was something worth notifying about + if should_notify(prev_state, new_state): + new_state.copy_and_replace( + last_federation_update=now, + ) + to_notify[user_id] = new_state - if target_user in self._user_cachemap: - state = self._user_cachemap[target_user].get_state() - else: - state = yield self.store.get_presence_state(target_user.localpart) - if "mtime" in state: - del state["mtime"] - state["presence"] = state.pop("state") - else: - # TODO(paul): Have remote server send us permissions set - state = self._get_or_offline_usercache(target_user).get_state() + self.user_to_current_state[user_id] = new_state + + # TODO: We should probably ensure there are no races hereafter - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") + if to_notify: + yield self._persist_and_notify(to_notify.values()) + + self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes -= set(to_notify.keys()) + + to_federation_ping = { + user_id: state for user_id, state in to_federation_ping.items() + if user_id not in to_notify + } + if to_federation_ping: + _, _, hosts_to_states = yield self._get_interested_parties( + to_federation_ping.values() ) - if as_event: - content = state + self._push_to_remotes(hosts_to_states) - content["user_id"] = target_user.to_string() + def _handle_timeouts(self): + """Checks the presence of users that have timed out and updates as + appropriate. + """ + now = self.clock.time_msec() - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = self.wheel_timer.fetch(now) - defer.returnValue({"type": "m.presence", "content": content}) - else: - defer.returnValue(state) + changes = {} # Actual changes we need to notify people about - @defer.inlineCallbacks - def get_states(self, target_users, auth_user, as_event=False, check_auth=True): - """A batched version of the `get_state` method that accepts a list of - `target_users` + for user_id in set(users_to_check): + state = self.user_to_current_state.get(user_id, None) + if not state: + continue - Args: - target_users (list): The list of UserID's whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_users` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? 
+            if self.hs.is_mine_id(user_id):
+                if state.state == PresenceState.OFFLINE:
+                    continue
-        Returns:
-            dict: A mapping from user -> presence_state
+                if state.state == PresenceState.ONLINE:
+                    if now - state.last_active > IDLE_TIMER:
+                        # Currently online, but last activity ages ago so auto
+                        # idle
+                        changes[user_id] = state.copy_and_replace(
+                            state=PresenceState.UNAVAILABLE,
+                        )
+                    elif now - state.last_active > LAST_ACTIVE_GRANULARITY:
+                        # So that we send down a notification that we've
+                        # stopped updating.
+                        changes[user_id] = state
+
+                if now - state.last_federation_update > FEDERATION_PING_INTERVAL:
+                    # Need to send ping to other servers to ensure they don't
+                    # timeout and set us to offline
+                    changes[user_id] = state
+
+                # If there have been no syncs for a while (and none ongoing),
+                # set presence to offline
+                if not self.user_to_num_current_syncs.get(user_id, 0):
+                    if now - state.last_user_sync > SYNC_ONLINE_TIMEOUT:
+                        changes[user_id] = state.copy_and_replace(
+                            state=PresenceState.OFFLINE,
+                        )
+            else:
+                # We expect to be poked occasionally by the other side.
+                # This is to protect against forgetful/buggy servers, so that
+                # no one gets stuck online forever.
+                if now - state.last_federation_update > FEDERATION_TIMEOUT:
+                    if state.state != PresenceState.OFFLINE:
+                        # The other side seems to have disappeared.
+                        changes[user_id] = state.copy_and_replace(
+                            state=PresenceState.OFFLINE,
+                        )
+
+        preserve_fn(self._update_states)(changes.values())
+
+    @defer.inlineCallbacks
+    def bump_presence_active_time(self, user):
+        """We've seen the user do something that indicates they're interacting
+        with the app.
         """
-        local_users, remote_users = partitionbool(
-            target_users,
-            lambda u: self.hs.is_mine(u)
-        )
+        user_id = user.to_string()
 
-        if check_auth:
-            for user in local_users:
-                visible = yield self.is_presence_visible(
-                    observer_user=auth_user,
-                    observed_user=user
-                )
+        prev_state = yield self.current_state_for_user(user_id)
 
-                if not visible:
-                    raise SynapseError(404, "Presence information not visible")
+        yield self._update_states([prev_state.copy_and_replace(
+            state=PresenceState.ONLINE,
+            last_active=self.clock.time_msec(),
+        )])
 
-        results = {}
-        if local_users:
-            for user in local_users:
-                if user in self._user_cachemap:
-                    results[user] = self._user_cachemap[user].get_state()
+    @defer.inlineCallbacks
+    def user_syncing(self, user_id, affect_presence=True):
+        """Returns a context manager that should surround any stream requests
+        from the user.
 
-            local_to_user = {u.localpart: u for u in local_users}
+        This allows us to keep track of who is currently streaming and who isn't
+        without having to have timers outside of this module to avoid flickering
+        when users disconnect/reconnect.
 
-            states = yield self.store.get_presence_states(
-                [u.localpart for u in local_users if u not in results]
-            )
+        Args:
+            user_id (str)
+            affect_presence (bool): If false this function will be a no-op.
+                Useful for streams that are not associated with an actual
+                client that is being used by a user.
+        """
+        if affect_presence:
+            curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
+            self.user_to_num_current_syncs[user_id] = curr_sync + 1
+
+            prev_state = yield self.current_state_for_user(user_id)
+            if prev_state.state == PresenceState.OFFLINE:
+                # If they're currently offline then bring them online, otherwise
+                # just update the last sync times.
+ yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active=self.clock.time_msec(), + last_user_sync=self.clock.time_msec(), + )]) + else: + yield self._update_states([prev_state.copy_and_replace( + last_user_sync=self.clock.time_msec(), + )]) - for local_part, state in states.items(): - if state is None: - continue - res = {"presence": state["state"]} - if "status_msg" in state and state["status_msg"]: - res["status_msg"] = state["status_msg"] - results[local_to_user[local_part]] = res - - for user in remote_users: - # TODO(paul): Have remote server send us permissions set - results[user] = self._get_or_offline_usercache(user).get_state() - - for state in results.values(): - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + @defer.inlineCallbacks + def _end(): + if affect_presence: + self.user_to_num_current_syncs[user_id] -= 1 - if as_event: - for user, state in results.items(): - content = state - content["user_id"] = user.to_string() + prev_state = yield self.current_state_for_user(user_id) + yield self._update_states([prev_state.copy_and_replace( + last_user_sync=self.clock.time_msec(), + )]) - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + @contextmanager + def _user_syncing(): + try: + yield + finally: + preserve_fn(_end)() - results[user] = {"type": "m.presence", "content": content} + defer.returnValue(_user_syncing()) - defer.returnValue(results) + @defer.inlineCallbacks + def current_state_for_user(self, user_id): + """Get the current presence state for a user. + """ + res = yield self.current_state_for_users([user_id]) + defer.returnValue(res[user_id]) @defer.inlineCallbacks - @log_function - def set_state(self, target_user, auth_user, state): - # return - # TODO (erikj): Turn this back on. Why did we end up sending EDUs - # everywhere? + def current_state_for_users(self, user_ids): + """Get the current presence state for multiple users. - if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = yield self.store.get_presence_for_users(missing) + states.update({state.user_id: state for state in res}) + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + states.update({ + user_id: UserPresenceState.default(user_id) + for user_id in missing + }) - if target_user != auth_user: - raise AuthError(400, "Cannot set another user's presence") + defer.returnValue(states) - if "status_msg" not in state: - state["status_msg"] = None + @defer.inlineCallbacks + def _get_interested_parties(self, states): + """Given a list of states return which entities (rooms, users, servers) + are interested in the given states. 
- for k in state.keys(): - if k not in ("presence", "status_msg"): - raise SynapseError( - 400, "Unexpected presence state key '%s'" % (k,) - ) + Returns: + 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + events = yield self.store.get_rooms_for_user(state.user_id) + for e in events: + room_ids_to_states.setdefault(e.room_id, []).append(state) + + plist = yield self.store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + hosts_to_states = {} + for room_id, states in room_ids_to_states.items(): + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(states) - if state["presence"] not in self.STATE_LEVELS: - raise SynapseError(400, "'%s' is not a valid presence state" % ( - state["presence"], - )) + for user_id, states in users_to_states.items(): + host = UserID.from_string(user_id).domain + hosts_to_states.setdefault(host, []).extend(states) - logger.debug("Updating presence state of %s to %s", - target_user.localpart, state["presence"]) + # TODO: de-dup hosts_to_states, as a single host might have multiple + # of same presence - state_to_store = dict(state) - state_to_store["state"] = state_to_store.pop("presence") + defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) + + @defer.inlineCallbacks + def _persist_and_notify(self, states): + """Persist states in the database, poke the notifier and send to + interested remote servers + """ + stream_id, max_token = yield self.store.update_presence(states) - statuscache = self._get_or_offline_usercache(target_user) - was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] - now_level = self.STATE_LEVELS[state["presence"]] + parties = yield self._get_interested_parties(states) + room_ids_to_states, users_to_states, hosts_to_states = parties - yield self.store.set_presence_state( - target_user.localpart, state_to_store + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] ) - yield collect_presencelike_data(self.distributor, target_user, state) - if now_level > was_level: - state["last_active"] = self.clock.time_msec() + self._push_to_remotes(hosts_to_states) + + def _push_to_remotes(self, hosts_to_states): + """Sends state updates to remote servers. + + Args: + hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` + """ + now = self.clock.time_msec() + for host, states in hosts_to_states.items(): + self.federation.send_edu( + destination=host, + edu_type="m.presence", + content={ + "push": [ + _format_user_presence_state(state, now) + for state in states + ] + } + ) + + @defer.inlineCallbacks + def incoming_presence(self, origin, content): + """Called when we receive a `m.presence` EDU from a remote server. + """ + now = self.clock.time_msec() + updates = [] + for push in content.get("push", []): + # A "push" contains a list of presence that we are probably interested + # in. + # TODO: Actually check if we're interested, rather than blindly + # accepting presence updates. 
+ user_id = push.get("user_id", None) + if not user_id: + logger.info( + "Got presence update from %r with no 'user_id': %r", + origin, push, + ) + continue - now_online = state["presence"] != PresenceState.OFFLINE - was_polling = target_user in self._user_cachemap + presence_state = push.get("presence", None) + if not presence_state: + logger.info( + "Got presence update from %r with no 'presence_state': %r", + origin, push, + ) + continue - if now_online and not was_polling: - yield self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - yield self.stop_polling_presence(target_user) + new_fields = { + "state": presence_state, + "last_federation_update": now, + } - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - yield self.changed_presencelike_data(target_user, state) + last_active_ago = push.get("last_active_ago", None) + if last_active_ago is not None: + new_fields["last_active"] = now - last_active_ago - def bump_presence_active_time(self, user, now=None): - if now is None: - now = self.clock.time_msec() + new_fields["status_msg"] = push.get("status_msg", None) - prev_state = self._get_or_make_usercache(user) - if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: - return + prev_state = yield self.current_state_for_user(user_id) + updates.append(prev_state.copy_and_replace(**new_fields)) - with PreserveLoggingContext(): - self.changed_presencelike_data(user, {"last_active": now}) + if updates: + yield self._update_states(updates) - def get_joined_rooms_for_user(self, user): - """Get the list of rooms a user is joined to. + @defer.inlineCallbacks + def get_state(self, target_user, as_event=False): + results = yield self.get_states( + [target_user.to_string()], + as_event=as_event, + ) + + defer.returnValue(results[0]) + + @defer.inlineCallbacks + def get_states(self, target_user_ids, as_event=False): + """Get the presence state for users. Args: - user(UserID): The user. + target_user_ids (list) + as_event (bool): Whether to format it as a client event or not. + Returns: - A Deferred of a list of room id strings. + list """ - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_joined_rooms_for_user(user) - def get_joined_users_for_room_id(self, room_id): - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_room_members(room_id) + updates = yield self.current_state_for_users(target_user_ids) + updates = updates.values() - @defer.inlineCallbacks - def changed_presencelike_data(self, user, state): - """Updates the presence state of a local user. + for user_id in set(target_user_ids) - set(u.user_id for u in updates): + updates.append(UserPresenceState.default(user_id)) - Args: - user(UserID): The user being updated. - state(dict): The new presence state for the user. - Returns: - A Deferred + now = self.clock.time_msec() + if as_event: + defer.returnValue([ + { + "type": "m.presence", + "content": _format_user_presence_state(state, now), + } + for state in updates + ]) + else: + defer.returnValue([ + _format_user_presence_state(state, now) for state in updates + ]) + + @defer.inlineCallbacks + def set_state(self, target_user, state): + """Set the presence state of the user. 
""" - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache(user, state) - yield self.push_presence(user, statuscache=statuscache) + status_msg = state.get("status_msg", None) + presence = state["presence"] - @log_function - def started_user_eventstream(self, user): - # TODO(paul): Use "last online" state - return self.set_state(user, user, {"presence": PresenceState.ONLINE}) + user_id = target_user.to_string() - @log_function - def stopped_user_eventstream(self, user): - # TODO(paul): Save current state as "last online" state - return self.set_state(user, user, {"presence": PresenceState.OFFLINE}) + prev_state = yield self.current_state_for_user(user_id) + + new_fields = { + "state": presence, + "status_msg": status_msg + } + + if presence == PresenceState.ONLINE: + new_fields["last_active"] = self.clock.time_msec() + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks def user_joined_room(self, user, room_id): - """Called via the distributor whenever a user joins a room. - Notifies the new member of the presence of the current members. - Notifies the current members of the room of the new member's presence. - - Args: - user(UserID): The user who joined the room. - room_id(str): The room id the user joined. + """Called (via the distributor) when a user joins a room. This funciton + sends presence updates to servers, either: + 1. the joining user is a local user and we send their presence to + all servers in the room. + 2. the joining user is a remote user and so we send presence for all + local users in the room. """ + # We only need to send presence to servers that don't have it yet. We + # don't need to send to local clients here, as that is done as part + # of the event stream/sync. + # TODO: Only send to servers not already in the room. if self.hs.is_mine(user): - # No actual update but we need to bump the serial anyway for the - # event source - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache( - user, room_ids=[room_id] - ) - self.push_update_to_local_and_remote( - observed_user=user, - room_ids=[room_id], - statuscache=statuscache, - ) + state = yield self.current_state_for_user(user.to_string()) - # We also want to tell them about current presence of people. - curr_users = yield self.get_joined_users_for_room_id(room_id) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + self._push_to_remotes({host: (state,) for host in hosts}) + else: + user_ids = yield self.store.get_users_in_room(room_id) + user_ids = filter(self.hs.is_mine_id, user_ids) - for local_user in [c for c in curr_users if self.hs.is_mine(c)]: - statuscache = yield self.update_presence_cache( - local_user, room_ids=[room_id], add_to_cache=False - ) + states = yield self.current_state_for_users(user_ids) - with PreserveLoggingContext(): - self.push_update_to_local_and_remote( - observed_user=local_user, - users_to_push=[user], - statuscache=statuscache, - ) + self._push_to_remotes({user.domain: states.values()}) @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Request the presence of a local or remote user for a local user""" + def get_presence_list(self, observer_user, accepted=None): + """Returns the presence for all users in their presence list. 
+ """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") + presence_list = yield self.store.get_presence_list( + observer_user.localpart, accepted=accepted + ) + + results = yield self.get_states( + target_user_ids=[row["observed_user_id"] for row in presence_list], + as_event=False, + ) + + is_accepted = { + row["observed_user_id"]: row["accepted"] for row in presence_list + } + + for result in results: + result.update({ + "accepted": is_accepted, + }) + + defer.returnValue(results) + + @defer.inlineCallbacks + def send_presence_invite(self, observer_user, observed_user): + """Sends a presence invite. + """ yield self.store.add_presence_list_pending( observer_user.localpart, observed_user.to_string() ) @@ -496,60 +693,41 @@ class PresenceHandler(BaseHandler): } ) - @defer.inlineCallbacks - def _should_accept_invite(self, observed_user, observer_user): - if not self.hs.is_mine(observed_user): - defer.returnValue(False) - - row = yield self.store.has_presence_state(observed_user.localpart) - if not row: - defer.returnValue(False) - - # TODO(paul): Eventually we'll ask the user's permission for this - # before accepting. For now just accept any invite request - defer.returnValue(True) - @defer.inlineCallbacks def invite_presence(self, observed_user, observer_user): - """Handles a m.presence_invite EDU. A remote or local user has - requested presence updates for a local user. If the invite is accepted - then allow the local or remote user to see the presence of the local - user. - - Args: - observed_user(UserID): The local user whose presence is requested. - observer_user(UserID): The remote or local user requesting presence. + """Handles new presence invites. """ - accept = yield self._should_accept_invite(observed_user, observer_user) - - if accept: - yield self.store.allow_presence_visible( - observed_user.localpart, observer_user.to_string() - ) + if not self.hs.is_mine(observed_user): + raise SynapseError(400, "User is not hosted on this Home Server") + # TODO: Don't auto accept if self.hs.is_mine(observer_user): - if accept: - yield self.accept_presence(observed_user, observer_user) - else: - yield self.deny_presence(observed_user, observer_user) + yield self.accept_presence(observed_user, observer_user) else: - edu_type = "m.presence_accept" if accept else "m.presence_deny" - - yield self.federation.send_edu( + self.federation.send_edu( destination=observer_user.domain, - edu_type=edu_type, + edu_type="m.presence_accept", content={ "observed_user": observed_user.to_string(), "observer_user": observer_user.to_string(), } ) + state_dict = yield self.get_state(observed_user, as_event=False) + + self.federation.send_edu( + destination=observer_user.domain, + edu_type="m.presence", + content={ + "push": [state_dict] + } + ) + @defer.inlineCallbacks def accept_presence(self, observed_user, observer_user): """Handles a m.presence_accept EDU. Mark a presence invite from a local or remote user as accepted in a local user's presence list. Starts polling for presence updates from the local or remote user. - Args: observed_user(UserID): The user to update in the presence list. observer_user(UserID): The owner of the presence list to update. @@ -558,15 +736,10 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - yield self.start_polling_presence( - observer_user, target_user=observed_user - ) - @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): """Handle a m.presence_deny EDU. 
Removes a local or remote user from a local user's presence list. - Args: observed_user(UserID): The local or remote user to remove from the list. @@ -584,7 +757,6 @@ class PresenceHandler(BaseHandler): def drop(self, observed_user, observer_user): """Remove a local or remote user from a local user's presence list and unsubscribe the local user from updates that user. - Args: observed_user(UserId): The local or remote user to remove from the list. @@ -599,710 +771,138 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.stop_polling_presence( - observer_user, target_user=observed_user - ) - - @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Get the presence list for a local user. The retured list includes - the current presence state for each user listed. - - Args: - observer_user(UserID): The local user whose presence list to fetch. - accepted(bool or None): If not none then only include users who - have or have not accepted the presence invite request. - Returns: - A Deferred list of presence state events. - """ - if not self.hs.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) - - results = [] - for row in presence_list: - observed_user = UserID.from_string(row["observed_user_id"]) - result = { - "observed_user": observed_user, "accepted": row["accepted"] - } - result.update( - self._get_or_offline_usercache(observed_user).get_state() - ) - if "last_active" in result: - result["last_active_ago"] = int( - self.clock.time_msec() - result.pop("last_active") - ) - results.append(result) - - defer.returnValue(results) - - @defer.inlineCallbacks - @log_function - def start_polling_presence(self, user, target_user=None, state=None): - """Subscribe a local user to presence updates from a local or remote - user. If no target_user is supplied then subscribe to all users stored - in the presence list for the local user. - - Additonally this pushes the current presence state of this user to all - target_users. That state can be provided directly or will be read from - the stored state for the local user. - - Also this attempts to notify the local user of the current state of - any local target users. - - Args: - user(UserID): The local user that whishes for presence updates. - target_user(UserID): The local or remote user whose updates are - wanted. - state(dict): Optional presence state for the local user. 
- """ - logger.debug("Start polling for presence from %s", user) - - if target_user: - target_users = set([target_user]) - room_ids = [] - else: - presence = yield self.store.get_presence_list( - user.localpart, accepted=True - ) - target_users = set([ - UserID.from_string(x["observed_user_id"]) for x in presence - ]) - - # Also include people in all my rooms - - room_ids = yield self.get_joined_rooms_for_user(user) - - if state is None: - state = yield self.store.get_presence_state(user.localpart) - else: - # statuscache = self._get_or_make_usercache(user) - # self._user_cachemap_latest_serial += 1 - # statuscache.update(state, self._user_cachemap_latest_serial) - pass - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=target_users, - room_ids=room_ids, - statuscache=self._get_or_make_usercache(user), - ) - - for target_user in target_users: - if self.hs.is_mine(target_user): - self._start_polling_local(user, target_user) - - # We want to tell the person that just came online - # presence state of people they are interested in? - self.push_update_to_clients( - users_to_push=[user], - ) - - deferreds = [] - remote_users = [u for u in target_users if not self.hs.is_mine(u)] - remoteusers_by_domain = partition(remote_users, lambda u: u.domain) - # Only poll for people in our get_presence_list - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append(self._start_polling_remote( - user, domain, remoteusers - )) - - yield defer.DeferredList(deferreds, consumeErrors=True) - - def _start_polling_local(self, user, target_user): - """Subscribe a local user to presence updates for a local user - - Args: - user(UserId): The local user that wishes for updates. - target_user(UserId): The local users whose updates are wanted. - """ - target_localpart = target_user.localpart - - if target_localpart not in self._local_pushmap: - self._local_pushmap[target_localpart] = set() - - self._local_pushmap[target_localpart].add(user) - - def _start_polling_remote(self, user, domain, remoteusers): - """Subscribe a local user to presence updates for remote users on a - given remote domain. - - Args: - user(UserID): The local user that wishes for updates. - domain(str): The remote server the local user wants updates from. - remoteusers(UserID): The remote users that local user wants to be - told about. - Returns: - A Deferred. - """ - to_poll = set() - - for u in remoteusers: - if u not in self._remote_recvmap: - self._remote_recvmap[u] = set() - to_poll.add(u) - - self._remote_recvmap[u].add(user) - - if not to_poll: - return defer.succeed(None) - - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"poll": [u.to_string() for u in to_poll]} - ) - - @log_function - def stop_polling_presence(self, user, target_user=None): - """Unsubscribe a local user from presence updates from a local or - remote user. If no target user is supplied then unsubscribe the user - from all presence updates that the user had subscribed to. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID or None): The user whose updates are no longer - wanted. - Returns: - A Deferred. 
- """ - logger.debug("Stop polling for presence from %s", user) - - if not target_user or self.hs.is_mine(target_user): - self._stop_polling_local(user, target_user=target_user) - - deferreds = [] - - if target_user: - if target_user not in self._remote_recvmap: - return - target_users = set([target_user]) - else: - target_users = self._remote_recvmap.keys() - - remoteusers = [u for u in target_users - if user in self._remote_recvmap[u]] - remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) - - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append( - self._stop_polling_remote(user, domain, remoteusers) - ) - - return defer.DeferredList(deferreds, consumeErrors=True) - - def _stop_polling_local(self, user, target_user): - """Unsubscribe a local user from presence updates from a local user on - this server. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID): The user whose updates are no longer wanted. - """ - for localpart in self._local_pushmap.keys(): - if target_user and localpart != target_user.localpart: - continue - - if user in self._local_pushmap[localpart]: - self._local_pushmap[localpart].remove(user) - - if not self._local_pushmap[localpart]: - del self._local_pushmap[localpart] - - @log_function - def _stop_polling_remote(self, user, domain, remoteusers): - """Unsubscribe a local user from presence updates from remote users on - a given domain. - - Args: - user(UserID): The local user that no longer wishes for updates. - domain(str): The remote server to unsubscribe from. - remoteusers([UserID]): The users on that remote server that the - local user no longer wishes to be updated about. - Returns: - A Deferred. - """ - to_unpoll = set() - - for u in remoteusers: - self._remote_recvmap[u].remove(user) - - if not self._remote_recvmap[u]: - del self._remote_recvmap[u] - to_unpoll.add(u) - - if not to_unpoll: - return defer.succeed(None) - - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"unpoll": [u.to_string() for u in to_unpoll]} - ) - - @defer.inlineCallbacks - @log_function - def push_presence(self, user, statuscache): - """ - Notify local and remote users of a change in presence of a local user. - Pushes the update to local clients and remote domains that are directly - subscribed to the presence of the local user. - Also pushes that update to any local user or remote domain that shares - a room with the local user. - - Args: - user(UserID): The local user whose presence was updated. - statuscache(UserPresenceCache): Cache of the user's presence state - Returns: - A Deferred. 
- """ - assert(self.hs.is_mine(user)) - - logger.debug("Pushing presence update from %s", user) - - localusers = set(self._local_pushmap.get(user.localpart, set())) - remotedomains = set(self._remote_sendmap.get(user.localpart, set())) - - # Reflect users' status changes back to themselves, so UIs look nice - # and also user is informed of server-forced pushes - localusers.add(user) - - room_ids = yield self.get_joined_rooms_for_user(user) - - if not localusers and not room_ids: - defer.returnValue(None) - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=localusers, - remote_domains=remotedomains, - room_ids=room_ids, - statuscache=statuscache, - ) - yield user_presence_changed(self.distributor, user, statuscache) - - @defer.inlineCallbacks - def incoming_presence(self, origin, content): - """Handle an incoming m.presence EDU. - For each presence update in the "push" list update our local cache and - notify the appropriate local clients. Only clients that share a room - or are directly subscribed to the presence for a user should be - notified of the update. - For each subscription request in the "poll" list start pushing presence - updates to the remote server. - For unsubscribe request in the "unpoll" list stop pushing presence - updates to the remote server. - - Args: - orgin(str): The source of this m.presence EDU. - content(dict): The content of this m.presence EDU. - Returns: - A Deferred. - """ - deferreds = [] - - for push in content.get("push", []): - user = UserID.from_string(push["user_id"]) - - logger.debug("Incoming presence update from %s", user) - - observers = set(self._remote_recvmap.get(user, set())) - if observers: - logger.debug( - " | %d interested local observers %r", len(observers), observers - ) - - room_ids = yield self.get_joined_rooms_for_user(user) - if room_ids: - logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) - - state = dict(push) - del state["user_id"] - - if "presence" not in state: - logger.warning( - "Received a presence 'push' EDU from %s without a" - " 'presence' key", origin - ) - continue - - if "last_active_ago" in state: - state["last_active"] = int( - self.clock.time_msec() - state.pop("last_active_ago") - ) - - self._user_cachemap_latest_serial += 1 - yield self.update_presence_cache(user, state, room_ids=room_ids) - - if not observers and not room_ids: - logger.debug(" | no interested observers or room IDs") - continue - - self.push_update_to_clients( - users_to_push=observers, room_ids=room_ids - ) - - user_id = user.to_string() - - if state["presence"] == PresenceState.OFFLINE: - self._remote_offline_serials.insert( - 0, - (self._user_cachemap_latest_serial, set([user_id])) - ) - while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: - self._remote_offline_serials.pop() # remove the oldest - if user in self._user_cachemap: - del self._user_cachemap[user] - else: - # Remove the user from remote_offline_serials now that they're - # no longer offline - for idx, elem in enumerate(self._remote_offline_serials): - (_, user_ids) = elem - user_ids.discard(user_id) - if not user_ids: - self._remote_offline_serials.pop(idx) - - for poll in content.get("poll", []): - user = UserID.from_string(poll) - - if not self.hs.is_mine(user): - continue - - # TODO(paul) permissions checks - - if user not in self._remote_sendmap: - self._remote_sendmap[user] = set() - - self._remote_sendmap[user].add(origin) - - deferreds.append(self._push_presence_remote(user, origin)) - - for unpoll in 
content.get("unpoll", []): - user = UserID.from_string(unpoll) - - if not self.hs.is_mine(user): - continue - - if user in self._remote_sendmap: - self._remote_sendmap[user].remove(origin) - - if not self._remote_sendmap[user]: - del self._remote_sendmap[user] - - yield defer.DeferredList(deferreds, consumeErrors=True) - - @defer.inlineCallbacks - def update_presence_cache(self, user, state={}, room_ids=None, - add_to_cache=True): - """Update the presence cache for a user with a new state and bump the - serial to the latest value. - - Args: - user(UserID): The user being updated - state(dict): The presence state being updated - room_ids(None or list of str): A list of room_ids to update. If - room_ids is None then fetch the list of room_ids the user is - joined to. - add_to_cache: Whether to add an entry to the presence cache if the - user isn't already in the cache. - Returns: - A Deferred UserPresenceCache for the user being updated. - """ - if room_ids is None: - room_ids = yield self.get_joined_rooms_for_user(user) - - for room_id in room_ids: - self._room_serials[room_id] = self._user_cachemap_latest_serial - if add_to_cache: - statuscache = self._get_or_make_usercache(user) - else: - statuscache = self._get_or_offline_usercache(user) - statuscache.update(state, serial=self._user_cachemap_latest_serial) - defer.returnValue(statuscache) + # TODO: Inform the remote that we've dropped the presence list. @defer.inlineCallbacks - def push_update_to_local_and_remote(self, observed_user, statuscache, - users_to_push=[], room_ids=[], - remote_domains=[]): - """Notify local clients and remote servers of a change in the presence - of a user. - - Args: - observed_user(UserID): The user to push the presence state for. - statuscache(UserPresenceCache): The cache for the presence state to - push. - users_to_push([UserID]): A list of local and remote users to - notify. - room_ids([str]): Notify the local and remote occupants of these - rooms. - remote_domains([str]): A list of remote servers to notify in - addition to those implied by the users_to_push and the - room_ids. - Returns: - A Deferred. 
- """ + def is_visible(self, observed_user, observer_user): + observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) + observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) - localusers, remoteusers = partitionbool( - users_to_push, - lambda u: self.hs.is_mine(u) - ) + observer_room_ids = set(r.room_id for r in observer_rooms) + observed_room_ids = set(r.room_id for r in observed_rooms) - localusers = set(localusers) + if observer_room_ids & observed_room_ids: + defer.returnValue(True) - self.push_update_to_clients( - users_to_push=localusers, room_ids=room_ids + accepted_observers = yield self.store.get_presence_list_observers_accepted( + observed_user.to_string() ) - remote_domains = set(remote_domains) - remote_domains |= set([r.domain for r in remoteusers]) - for room_id in room_ids: - remote_domains.update( - (yield self.store.get_joined_hosts_for_room(room_id)) - ) + defer.returnValue(observer_user.to_string() in accepted_observers) - remote_domains.discard(self.hs.hostname) - - deferreds = [] - for domain in remote_domains: - logger.debug(" | push to remote domain %s", domain) - deferreds.append( - self._push_presence_remote( - observed_user, domain, state=statuscache.get_state() - ) - ) - yield defer.DeferredList(deferreds, consumeErrors=True) +def should_notify(old_state, new_state): + """Decides if a presence state change should be sent to interested parties. + """ + if old_state.status_msg != new_state.status_msg: + return True - defer.returnValue((localusers, remote_domains)) + if old_state.state == PresenceState.ONLINE: + if new_state.state != PresenceState.ONLINE: + # Always notify for online -> anything + return True - def push_update_to_clients(self, users_to_push=[], room_ids=[]): - """Notify clients of a new presence event. + if new_state.currently_active != old_state.currently_active: + return True - Args: - users_to_push([UserID]): List of users to notify. - room_ids([str]): List of room_ids to notify. - """ - with PreserveLoggingContext(): - self.notifier.on_new_event( - "presence_key", - self._user_cachemap_latest_serial, - users_to_push, - room_ids, - ) + if new_state.last_active - old_state.last_active > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + return True - @defer.inlineCallbacks - def _push_presence_remote(self, user, destination, state=None): - """Push a user's presence to a remote server. If a presence state event - that event is sent. Otherwise a new state event is constructed from the - stored presence state. - The last_active is replaced with last_active_ago in case the wallclock - time on the remote server is different to the time on this server. - Sends an EDU to the remote server with the current presence state. + if old_state.state != new_state.state: + # Nothing to report. + return True - Args: - user(UserID): The user to push the presence state for. - destination(str): The remote server to send state to. - state(dict): The state to push, or None to use the current stored - state. - Returns: - A Deferred. 
- """ - if state is None: - state = yield self.store.get_presence_state(user.localpart) - del state["mtime"] - state["presence"] = state.pop("state") - - if user in self._user_cachemap: - state["last_active"] = ( - self._user_cachemap[user].get_state()["last_active"] - ) + return False - yield collect_presencelike_data(self.distributor, user, state) - if "last_active" in state: - state = dict(state) - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) - - user_state = {"user_id": user.to_string(), } - user_state.update(state) +def _format_user_presence_state(state, now): + """Convert UserPresenceState to a format that can be sent down to clients + and to other servers. + """ + content = { + "presence": state.state, + "user_id": state.user_id, + } + if state.last_active: + content["last_active_ago"] = now - state.last_active + if state.status_msg and state.state != PresenceState.OFFLINE: + content["status_msg"] = state.status_msg + if state.state == PresenceState.ONLINE: + content["currently_active"] = state.currently_active - yield self.federation.send_edu( - destination=destination, - edu_type="m.presence", - content={"push": [user_state, ], } - ) + return content class PresenceEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() + self.store = hs.get_datastore() @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, **kwargs): - from_key = int(from_key) + def get_new_events(self, user, from_key, room_ids=None, include_offline=True, + **kwargs): + # The process for getting presence events are: + # 1. Get the rooms the user is in. + # 2. Get the list of user in the rooms. + # 3. Get the list of users that are in the user's presence list. + # 4. If there is a from_key set, cross reference the list of users + # with the `presence_stream_cache` to see which ones we actually + # need to check. + # 5. Load current state for the users. + # + # We don't try and limit the presence updates by the current token, as + # sending down the rare duplicate is not a concern. 
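As a rough standalone illustration of step 4 above (this is not part of the patch, and the names below are made up for the example rather than the real StreamChangeCache API): the idea is to remember the last stream position at which each user's presence changed, so a sync with a `from_key` only has to load presence rows for users that may have changed since then.

    class ChangeCacheSketch(object):
        """Remembers the last stream position at which each entity changed."""

        def __init__(self):
            self._last_changed = {}

        def entity_has_changed(self, entity, stream_pos):
            # Called whenever a new presence row is written for `entity`.
            if stream_pos > self._last_changed.get(entity, 0):
                self._last_changed[entity] = stream_pos

        def get_entities_changed(self, entities, from_pos):
            # Of the candidate `entities`, keep only those that may have changed
            # after `from_pos`; callers can skip loading state for the rest.
            return set(e for e in entities if self._last_changed.get(e, 0) > from_pos)


    cache = ChangeCacheSketch()
    cache.entity_has_changed("@alice:example.com", 7)
    cache.entity_has_changed("@bob:example.com", 3)

    # Only @alice:example.com needs to be re-read for a sync from position 5.
    print(cache.get_entities_changed({"@alice:example.com", "@bob:example.com"}, 5))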
+ + user_id = user.to_string() + if from_key is not None: + from_key = int(from_key) room_ids = room_ids or [] presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - max_serial = presence._user_cachemap_latest_serial - - clock = self.clock - latest_serial = 0 - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] > from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - cached = cachemap[observed_user] - - if cached.serial <= from_key or cached.serial > max_serial: - continue - - latest_serial = max(cached.serial, latest_serial) - updates.append(cached.make_event(user=observed_user, clock=clock)) + if not room_ids: + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(e.room_id for e in rooms) - # TODO(paul): limit - - for serial, user_ids in presence._remote_offline_serials: - if serial <= from_key: - break - - if serial > max_serial: - continue - - latest_serial = max(latest_serial, serial) - for u in user_ids: - updates.append({ - "type": "m.presence", - "content": {"user_id": u, "presence": PresenceState.OFFLINE}, - }) - # TODO(paul): For the v2 API we want to tell the client their from_key - # is too old if we fell off the end of the _remote_offline_serials - # list, and get them to invalidate+resync. In v1 we have no such - # concept so this is a best-effort result. - - if updates: - defer.returnValue((updates, latest_serial)) - else: - defer.returnValue(([], presence._user_cachemap_latest_serial)) - - def get_current_key(self): - presence = self.hs.get_handlers().presence_handler - return presence._user_cachemap_latest_serial + user_ids_to_check = set() + for room_id in room_ids: + users = yield self.store.get_users_in_room(room_id) + user_ids_to_check.update(users) - @defer.inlineCallbacks - def get_pagination_rows(self, user, pagination_config, key): - # TODO (erikj): Does this make sense? Ordering? + plist = yield self.store.get_presence_list_accepted(user.localpart) + user_ids_to_check.update([row["observed_user_id"] for row in plist]) - from_key = int(pagination_config.from_key) + # Always include yourself. Only really matters for when the user is + # not in any rooms, but still. 
+ user_ids_to_check.add(user_id) - if pagination_config.to_key: - to_key = int(pagination_config.to_key) - else: - to_key = -1 + max_token = self.store.get_current_presence_token() - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list + if from_key: + user_ids_changed = self.store.presence_stream_cache.get_entities_changed( + user_ids_to_check, from_key, ) - room_ids = yield presence.get_joined_rooms_for_user(user) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] >= from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - if not (to_key < cachemap[observed_user].serial <= from_key): - continue - - updates.append((observed_user, cachemap[observed_user])) - - # TODO(paul): limit - - if updates: - clock = self.clock - - earliest_serial = max([x[1].serial for x in updates]) - data = [x[1].make_event(user=x[0], clock=clock) for x in updates] - - defer.returnValue((data, earliest_serial)) else: - defer.returnValue(([], 0)) - + user_ids_changed = user_ids_to_check -class UserPresenceCache(object): - """Store an observed user's state and status message. + updates = yield presence.current_state_for_users(user_ids_changed) - Includes the update timestamp. - """ - def __init__(self): - self.state = {"presence": PresenceState.OFFLINE} - self.serial = None - - def __repr__(self): - return "UserPresenceCache(state=%r, serial=%r)" % ( - self.state, self.serial - ) - - def update(self, state, serial): - assert("mtime_age" not in state) + now = self.clock.time_msec() - self.state.update(state) - # Delete keys that are now 'None' - for k in self.state.keys(): - if self.state[k] is None: - del self.state[k] - - self.serial = serial - - if "status_msg" in state: - self.status_msg = state["status_msg"] - else: - self.status_msg = None - - def get_state(self): - # clone it so caller can't break our cache - state = dict(self.state) - return state - - def make_event(self, user, clock): - content = self.get_state() - content["user_id"] = user.to_string() + defer.returnValue(([ + { + "type": "m.presence", + "content": _format_user_presence_state(s, now), + } + for s in updates.values() + if include_offline or s.state != PresenceState.OFFLINE + ], max_token)) - if "last_active" in content: - content["last_active_ago"] = int( - clock.time_msec() - content.pop("last_active") - ) + def get_current_key(self): + return self.store.get_current_presence_token() - return {"type": "m.presence", "content": content} + def get_pagination_rows(self, user, pagination_config, key): + return self.get_new_events(user, from_key=None, include_offline=False) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 629e6e3594..7084a7396f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -49,6 +49,9 @@ class ProfileHandler(BaseHandler): distributor = hs.get_distributor() self.distributor = distributor + distributor.declare("collect_presencelike_data") + distributor.declare("changed_presencelike_data") + distributor.observe("registered_user", self.registered_user) distributor.observe( diff --git a/synapse/handlers/sync.py 
b/synapse/handlers/sync.py
index 1d0f0058a2..c5c13e085b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -582,6 +582,28 @@ class SyncHandler(BaseHandler):
             if room_sync:
                 joined.append(room_sync)

+        # For each newly joined room, we want to send down presence of
+        # existing users.
+        presence_handler = self.hs.get_handlers().presence_handler
+        extra_presence_users = set()
+        for room_id in newly_joined_rooms:
+            users = yield self.store.get_users_in_room(room_id)
+            extra_presence_users.update(users)
+
+        # For each new member, send down presence.
+        for joined_sync in joined:
+            it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
+            for event in it:
+                if event.type == EventTypes.Member:
+                    if event.membership == Membership.JOIN:
+                        extra_presence_users.add(event.state_key)
+
+        states = yield presence_handler.get_states(
+            [u for u in extra_presence_users if u != user_id],
+            as_event=True,
+        )
+        presence.extend(states)
+
         account_data_for_user = sync_config.filter_collection.filter_account_data(
             self.account_data_for_user(account_data)
         )
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index a6f8754e32..27ea5f2a43 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -17,7 +17,7 @@
 """
 from twisted.internet import defer

-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, AuthError
 from synapse.types import UserID
 from .base import ClientV1RestServlet, client_path_patterns

@@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)

-        state = yield self.handlers.presence_handler.get_state(
-            target_user=user, auth_user=requester.user)
+        if requester.user != user:
+            allowed = yield self.handlers.presence_handler.is_visible(
+                observed_user=user, observer_user=requester.user,
+            )
+
+            if not allowed:
+                raise AuthError(403, "You are not allowed to see their presence.")
+
+        state = yield self.handlers.presence_handler.get_state(target_user=user)

         defer.returnValue((200, state))

@@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)

+        if requester.user != user:
+            raise AuthError(403, "Can only set your own presence state")
+
         state = {}

         try:
             content = json.loads(request.content.read())
@@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         except:
             raise SynapseError(400, "Unable to parse state")

-        yield self.handlers.presence_handler.set_state(
-            target_user=user, auth_user=requester.user, state=state)
+        yield self.handlers.presence_handler.set_state(user, state)

         defer.returnValue((200, {}))

@@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "Cannot get another user's presence list")

         presence = yield self.handlers.presence_handler.get_presence_list(
-            observer_user=user, accepted=True)
-
-        for p in presence:
-            observed_user = p.pop("observed_user")
-            p["user_id"] = observed_user.to_string()
+            observer_user=user, accepted=True
+        )

         defer.returnValue((200, presence))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 24706f9383..a8e89c7fe9 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -304,18 +304,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
             if event["type"] != EventTypes.Member:
                 continue
             chunk.append(event)
-            # FIXME: should probably be state_key here, not user_id
-            target_user = UserID.from_string(event["user_id"])
-            # Presence is an optional cache; don't fail if we can't fetch it
-            try:
-                presence_handler = self.handlers.presence_handler
-                presence_state = yield presence_handler.get_state(
-                    target_user=target_user,
-                    auth_user=requester.user,
-                )
-                event["content"].update(presence_state)
-            except:
-                pass

         defer.returnValue((200, {
             "chunk": chunk
@@ -541,6 +529,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
         "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
     )

+    def __init__(self, hs):
+        super(RoomTypingRestServlet, self).__init__(hs)
+        self.presence_handler = hs.get_handlers().presence_handler
+
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, user_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -552,6 +544,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):

         typing_handler = self.handlers.typing_notification_handler

+        yield self.presence_handler.bump_presence_active_time(requester.user)
+
         if content["typing"]:
             yield typing_handler.started_typing(
                 target_user=target_user,
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index eb4b369a3d..b831d8c95e 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_handlers().receipts_handler
+        self.presence_handler = hs.get_handlers().presence_handler

     @defer.inlineCallbacks
     def on_POST(self, request, room_id, receipt_type, event_id):
@@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
         if receipt_type != "m.read":
             raise SynapseError(400, "Receipt type must be 'm.read'")

+        yield self.presence_handler.bump_presence_active_time(requester.user)
+
         yield self.receipts_handler.received_client_receipt(
             room_id,
             receipt_type,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index accbc6cfac..de4a020ad4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -25,6 +25,7 @@ from synapse.events.utils import (
 )
 from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
 from synapse.api.errors import SynapseError
+from synapse.api.constants import PresenceState
 from ._base import client_v2_patterns

 import copy
@@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
         self.sync_handler = hs.get_handlers().sync_handler
         self.clock = hs.get_clock()
         self.filtering = hs.get_filtering()
+        self.presence_handler = hs.get_handlers().presence_handler

     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet):
         else:
             since_token = None

-        if set_presence == "online":
-            yield self.event_stream_handler.started_stream(user)
+        affect_presence = set_presence != PresenceState.OFFLINE

-        try:
+        if affect_presence:
+            yield self.presence_handler.set_state(user, {"presence": set_presence})
+
+        context = yield self.presence_handler.user_syncing(
+            user.to_string(), affect_presence=affect_presence,
+        )
+        with context:
             sync_result = yield self.sync_handler.wait_for_sync_for_user(
                 sync_config, since_token=since_token, timeout=timeout,
                 full_state=full_state
             )
-        finally:
-            if set_presence == "online":
-                self.event_stream_handler.stopped_stream(user)

         time_now = self.clock.time_msec()

diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 5a9e7720d9..8c3cf9e801 100644
---
a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,7 @@ from .appservice import ( from ._base import Cache from .directory import DirectoryStore from .events import EventsStore -from .presence import PresenceStore +from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore from .registration import RegistrationStore from .room import RoomStore @@ -47,6 +47,7 @@ from .account_data import AccountDataStore from util.id_generators import IdGenerator, StreamIdGenerator +from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -110,6 +111,9 @@ class DataStore(RoomMemberStore, RoomStore, self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) @@ -119,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) - events_max = self._stream_id_gen.get_max_token(None) + events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token(None) + account_max = self._account_data_id_gen.get_max_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) + self.__presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_max_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", min_presence_val, + prefilled_cache=presence_cache_prefill + ) + super(DataStore, self).__init__(hs) + def take_presence_startup_info(self): + active_on_startup = self.__presence_on_startup + self.__presence_on_startup = None + return active_on_startup + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will @@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() + txn.close() cache = { row[0]: int(row[1]) @@ -174,6 +197,27 @@ class DataStore(RoomMemberStore, RoomStore, return cache, min_val + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active, last_federation_update," + " last_user_sync, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" 
+ ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..0fd5d497ab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index ef525f34c5..b133979102 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,73 +14,128 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.api.constants import PresenceState +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from collections import namedtuple from twisted.internet import defer -class PresenceStore(SQLBaseStore): - def create_presence(self, user_localpart): - res = self._simple_insert( - table="presence", - values={"user_id": user_localpart}, - desc="create_presence", +class UserPresenceState(namedtuple("UserPresenceState", + ("user_id", "state", "last_active", "last_federation_update", + "last_user_sync", "status_msg", "currently_active"))): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. 
+ """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active=0, + last_federation_update=0, + last_user_sync=0, + status_msg=None, + currently_active=False, ) - self.get_presence_state.invalidate((user_localpart,)) - return res - def has_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["user_id"], - allow_none=True, - desc="has_presence_state", +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_id_manager = yield self._presence_id_gen.get_next(self) + with stream_id_manager as stream_id: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, stream_id, presence_states, + ) + + defer.returnValue((stream_id, self._presence_id_gen.get_max_token())) + + def _update_presence_txn(self, txn, stream_id, presence_states): + for state in presence_states: + txn.call_after( + self.presence_stream_cache.entity_has_changed, + state.user_id, stream_id, + ) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active": state.last_active, + "last_federation_update": state.last_federation_update, + "last_user_sync": state.last_user_sync, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], ) - @cached(max_entries=2000) - def get_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + # Delete old rows to stop database from getting really big + sql = ( + "DELETE FROM presence_stream WHERE" + " stream_id < ?" + " AND user_id IN (%s)" ) - @cachedList(get_presence_state.cache, list_name="user_localparts", - inlineCallbacks=True) - def get_presence_states(self, user_localparts): + batches = ( + presence_states[i:i + 50] + for i in xrange(0, len(presence_states), 50) + ) + for states in batches: + args = [stream_id] + args.extend(s.user_id for s in states) + txn.execute( + sql % (",".join("?" 
for _ in states),), + args + ) + + @defer.inlineCallbacks + def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( - table="presence", + table="presence_stream", column="user_id", - iterable=user_localparts, - retcols=("user_id", "state", "status_msg", "mtime",), - desc="get_presence_states", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active", + "last_federation_update", + "last_user_sync", + "status_msg", + "currently_active", + ), ) - defer.returnValue({ - row["user_id"]: { - "state": row["state"], - "status_msg": row["status_msg"], - "mtime": row["mtime"], - } - for row in rows - }) + for row in rows: + row["currently_active"] = bool(row["currently_active"]) - @defer.inlineCallbacks - def set_presence_state(self, user_localpart, new_state): - res = yield self._simple_update_one( - table="presence", - keyvalues={"user_id": user_localpart}, - updatevalues={"state": new_state["state"], - "status_msg": new_state["status_msg"], - "mtime": self._clock.time_msec()}, - desc="set_presence_state", - ) + defer.returnValue([UserPresenceState(**row) for row in rows]) - self.get_presence_state.invalidate((user_localpart,)) - defer.returnValue(res) + def get_current_presence_token(self): + return self._presence_id_gen.get_max_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -128,6 +183,7 @@ class PresenceStore(SQLBaseStore): desc="set_presence_list_accepted", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -154,6 +210,19 @@ class PresenceStore(SQLBaseStore): desc="get_presence_list_accepted", ) + @cachedInlineCallbacks() + def get_presence_list_observers_accepted(self, observed_userid): + user_localparts = yield self._simple_select_onecol( + table="presence_list", + keyvalues={"observed_user_id": observed_userid, "accepted": True}, + retcol="user_id", + desc="get_presence_list_accepted", + ) + + defer.returnValue([ + "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts + ]) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): yield self._simple_delete_one( @@ -163,3 +232,4 @@ class PresenceStore(SQLBaseStore): desc="del_presence_list", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..14f5e3d30a --- /dev/null +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active BIGINT, + last_federation_update BIGINT, + last_user_sync BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4ab9..5ce54f76de 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -130,9 +130,11 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self, *args): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. + + Used to take a DataStore param, which is no longer needed. """ with self._lock: if self._unfinished_ids: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 133671e238..3b9da5b34a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -42,7 +42,7 @@ class Clock(object): def time_msec(self): """Returns the current system time in miliseconds since epoch.""" - return self.time() * 1000 + return int(self.time() * 1000) def looping_call(self, f, msec): l = task.LoopingCall(f) diff --git a/tests/utils.py b/tests/utils.py index 3b1eb50d8d..f71125042b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -224,12 +224,12 @@ class MockClock(object): def time_msec(self): return self.time() * 1000 - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() def wrapped_callback(): LoggingContext.thread_local.current_context = current_context - callback() + callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] self.timers.append(t) -- cgit 1.4.1 From 700487a7c72c769537fb04c56d2e9c6dd29d7f84 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 19 Feb 2016 15:34:38 +0000 Subject: Fix flake8 warnings for tests --- tests/__init__.py | 1 - tests/api/test_auth.py | 4 +- tests/api/test_filtering.py | 12 +- tests/api/test_ratelimiting.py | 1 + tests/appservice/test_appservice.py | 2 +- tests/appservice/test_scheduler.py | 4 +- tests/crypto/__init__.py | 1 - tests/events/test_utils.py | 1 + tests/handlers/test_appservice.py | 2 - tests/handlers/test_auth.py | 1 - tests/handlers/test_profile.py | 9 +- tests/handlers/test_typing.py | 115 +++++++-------- tests/rest/client/v1/__init__.py | 1 - tests/rest/client/v1/test_profile.py | 56 ++++--- tests/rest/client/v1/test_rooms.py | 220 ++++++++++++++-------------- tests/rest/client/v1/test_typing.py | 44 +++--- tests/rest/client/v1/utils.py | 5 +- tests/rest/client/v2_alpha/__init__.py | 2 +- tests/rest/client/v2_alpha/test_filter.py | 19 +-- tests/rest/client/v2_alpha/test_register.py | 4 +- tests/storage/event_injector.py | 8 +- tests/storage/test__base.py | 8 +- tests/storage/test_background_update.py | 8 +- tests/storage/test_base.py | 89 ++++++----- tests/storage/test_events.py | 1 - tests/storage/test_profile.py | 2 +- tests/storage/test_registration.py | 6 +- tests/storage/test_room.py | 14 +- tests/storage/test_roommember.py | 22 ++- tests/storage/test_stream.py | 4 +- tests/test_distributor.py | 11 +- tests/test_test_utils.py | 3 +- tests/unittest.py | 7 +- tests/util/__init__.py | 1 - tests/util/test_dict_cache.py | 1 - 
tests/util/test_snapshot_cache.py | 1 + tests/util/test_treecache.py | 1 + tests/utils.py | 30 ++-- tox.ini | 2 +- 39 files changed, 359 insertions(+), 364 deletions(-) (limited to 'tests/utils.py') diff --git a/tests/__init__.py b/tests/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 474c5c418f..7e7b0b4b1d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,9 +281,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < 1") # ms - self.hs.clock.now = 5000 # seconds + self.hs.clock.now = 5000 # seconds yield self.auth.get_user_from_macaroon(macaroon.serialize()) # TODO(daniel): Turn on the check that we validate expiration, when we diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index ceb0089268..4297a89f85 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -21,7 +21,6 @@ from tests.utils import ( MockHttpResource, DeferredMockCallable, setup_test_homeserver ) -from synapse.types import UserID from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -356,7 +355,6 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.*"] } } - user = UserID.from_string("@" + user_localpart + ":test") filter_id = yield self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json, @@ -411,7 +409,6 @@ class FilteringTestCase(unittest.TestCase): } } } - user = UserID.from_string("@" + user_localpart + ":test") filter_id = yield self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json, @@ -440,7 +437,6 @@ class FilteringTestCase(unittest.TestCase): } } } - user = UserID.from_string("@" + user_localpart + ":test") filter_id = yield self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json, @@ -476,12 +472,12 @@ class FilteringTestCase(unittest.TestCase): ) self.assertEquals(filter_id, 0) - self.assertEquals(user_filter_json, - (yield self.datastore.get_user_filter( + self.assertEquals(user_filter_json, ( + yield self.datastore.get_user_filter( user_localpart=user_localpart, filter_id=0, - )) - ) + ) + )) @defer.inlineCallbacks def test_get_filter(self): diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dd0bc19ecf..c45b59b36c 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -2,6 +2,7 @@ from synapse.api.ratelimiting import Ratelimiter from tests import unittest + class TestRatelimiter(unittest.TestCase): def test_allowed(self): diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index ef48bbc296..d6cc1881e9 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -14,7 +14,7 @@ # limitations under the License. 
from synapse.appservice import ApplicationService -from mock import Mock, PropertyMock +from mock import Mock from tests import unittest diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index c9c2d36210..631a229332 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice import ApplicationServiceState, AppServiceTransaction +from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( _ServiceQueuer, _TransactionController, _Recoverer ) @@ -235,7 +235,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - self.txn_ctrl.send = Mock(side_effect=lambda x,y: send_return_list.pop(0)) + self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0)) # send events for different ASes and make sure they are sent self.queuer.enqueue(srv1, srv_1_event) diff --git a/tests/crypto/__init__.py b/tests/crypto/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/crypto/__init__.py +++ b/tests/crypto/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 894d0c3845..fb0953c4ec 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -19,6 +19,7 @@ from .. import unittest from synapse.events import FrozenEvent from synapse.events.utils import prune_event + class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. 
""" diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ba6e2c640e..7ddbbb9b4a 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -129,8 +129,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - - def _mkservice(self, is_interested): service = Mock() service.is_interested = Mock(return_value=is_interested) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 2f21bf91e5..21077cbe9a 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -15,7 +15,6 @@ import pymacaroons -from mock import Mock, NonCallableMock from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 95c87f0ebd..a87703bbfd 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -41,8 +41,10 @@ class ProfileTestCase(unittest.TestCase): ]) self.query_handlers = {} + def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler + self.mock_federation.register_query_handler = register_query_handler hs = yield setup_test_homeserver( @@ -63,7 +65,7 @@ class ProfileTestCase(unittest.TestCase): self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") - self.bob = UserID.from_string("@4567:test") + self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") yield self.store.create_profile(self.frank.localpart) @@ -133,8 +135,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url(self.frank, self.frank, - "http://my.server/pic.gif") + yield self.handler.set_avatar_url( + self.frank, self.frank, "http://my.server/pic.gif" + ) self.assertEquals( (yield self.store.get_profile_avatar_url(self.frank.localpart)), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 763c04d667..3955e7f5b1 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -138,9 +138,9 @@ class TypingNotificationsTestCase(unittest.TestCase): self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user @defer.inlineCallbacks - def fetch_room_distributions_into(room_id, localusers=None, - remotedomains=None, ignore_user=None): - + def fetch_room_distributions_into( + room_id, localusers=None, remotedomains=None, ignore_user=None + ): members = yield get_room_members(room_id) for member in members: if ignore_user is not None and member == ignore_user: @@ -153,7 +153,8 @@ class TypingNotificationsTestCase(unittest.TestCase): if remotedomains is not None: remotedomains.add(member.domain) self.room_member_handler.fetch_room_distributions_into = ( - fetch_room_distributions_into) + fetch_room_distributions_into + ) def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: @@ -207,9 +208,12 @@ class TypingNotificationsTestCase(unittest.TestCase): put_json = self.mock_http_client.put_json put_json.expect_call_and_return( - call("farm", + call( + "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("farm", "m.typing", + data=_expect_edu( + "farm", + "m.typing", content={ "room_id": self.room_id, "user_id": self.u_apple.to_string(), @@ -237,9 +241,12 @@ class 
TypingNotificationsTestCase(unittest.TestCase): self.assertEquals(self.event_source.get_current_key(), 0) - yield self.mock_federation_resource.trigger("PUT", + yield self.mock_federation_resource.trigger( + "PUT", "/_matrix/federation/v1/send/1000000/", - _make_edu_json("farm", "m.typing", + _make_edu_json( + "farm", + "m.typing", content={ "room_id": self.room_id, "user_id": self.u_onion.to_string(), @@ -257,16 +264,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0 ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_onion.to_string()], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.u_onion.to_string()], + }, + }]) @defer.inlineCallbacks def test_stopped_typing(self): @@ -274,9 +278,12 @@ class TypingNotificationsTestCase(unittest.TestCase): put_json = self.mock_http_client.put_json put_json.expect_call_and_return( - call("farm", + call( + "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu("farm", "m.typing", + data=_expect_edu( + "farm", + "m.typing", content={ "room_id": self.room_id, "user_id": self.u_apple.to_string(), @@ -317,16 +324,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [], + }, + }]) @defer.inlineCallbacks def test_typing_timeout(self): @@ -351,16 +355,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.u_apple.to_string()], + }, + }]) self.clock.advance_time(11) @@ -373,16 +374,13 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=1, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [], + }, + }]) # SYN-230 - see if we can still set after timeout @@ -403,13 +401,10 @@ class TypingNotificationsTestCase(unittest.TestCase): room_ids=[self.room_id], from_key=0, ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.u_apple.to_string()], + }, + }]) diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/rest/client/v1/__init__.py +++ b/tests/rest/client/v1/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index c1a3f52043..0785965de2 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -65,8 +65,9 @@ class ProfileTestCase(unittest.TestCase): mocked_get = self.mock_handler.get_displayname mocked_get.return_value = defer.succeed("Frank") - (code, response) = yield self.mock_resource.trigger("GET", - "/profile/%s/displayname" % (myid), None) + (code, response) = yield self.mock_resource.trigger( + "GET", "/profile/%s/displayname" % (myid), None + ) self.assertEquals(200, code) self.assertEquals({"displayname": "Frank"}, response) @@ -77,9 +78,11 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_displayname mocked_set.return_value = defer.succeed(()) - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/displayname" % (myid), - '{"displayname": "Frank Jr."}') + (code, response) = yield self.mock_resource.trigger( + "PUT", + "/profile/%s/displayname" % (myid), + '{"displayname": "Frank Jr."}' + ) self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") @@ -91,19 +94,23 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_displayname mocked_set.side_effect = AuthError(400, "message") - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/displayname" % ("@4567:test"), '"Frank Jr."') + (code, response) = yield self.mock_resource.trigger( + "PUT", "/profile/%s/displayname" % ("@4567:test"), '"Frank Jr."' + ) - self.assertTrue(400 <= code < 499, - msg="code %d is in the 4xx range" % (code)) + self.assertTrue( + 400 <= code < 499, + msg="code %d is in the 4xx range" % (code) + ) @defer.inlineCallbacks def test_get_other_name(self): mocked_get = self.mock_handler.get_displayname mocked_get.return_value = defer.succeed("Bob") - (code, response) = yield self.mock_resource.trigger("GET", - "/profile/%s/displayname" % ("@opaque:elsewhere"), None) + (code, response) = yield self.mock_resource.trigger( + "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None + ) self.assertEquals(200, code) self.assertEquals({"displayname": "Bob"}, response) @@ -113,19 +120,23 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_displayname mocked_set.side_effect = SynapseError(400, "message") - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/displayname" % ("@opaque:elsewhere"), None) + (code, response) = yield self.mock_resource.trigger( + "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), None + ) - self.assertTrue(400 <= code <= 499, - msg="code %d is in the 4xx range" % (code)) + self.assertTrue( + 400 <= code <= 499, + msg="code %d is in the 4xx range" % (code) + ) @defer.inlineCallbacks def test_get_my_avatar(self): mocked_get = self.mock_handler.get_avatar_url mocked_get.return_value = defer.succeed("http://my.server/me.png") - (code, response) = yield self.mock_resource.trigger("GET", - "/profile/%s/avatar_url" % (myid), None) + (code, response) = yield self.mock_resource.trigger( + "GET", "/profile/%s/avatar_url" % (myid), None + ) self.assertEquals(200, code) self.assertEquals({"avatar_url": "http://my.server/me.png"}, response) @@ -136,12 +147,13 @@ class ProfileTestCase(unittest.TestCase): mocked_set = self.mock_handler.set_avatar_url mocked_set.return_value = defer.succeed(()) - (code, response) = yield self.mock_resource.trigger("PUT", - "/profile/%s/avatar_url" % (myid), - 
'{"avatar_url": "http://my.server/pic.gif"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", + "/profile/%s/avatar_url" % (myid), + '{"avatar_url": "http://my.server/pic.gif"}' + ) self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], - "http://my.server/pic.gif") + self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a90b9dc3d8..afca5303ba 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -82,19 +82,22 @@ class RoomPermissionsTestCase(RestTestCase): is_public=True) # send a message in one of the rooms - self.created_rmid_msg_path = ("/rooms/%s/send/m.room.message/a1" % - (self.created_rmid)) + self.created_rmid_msg_path = ( + "/rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ) (code, response) = yield self.mock_resource.trigger( - "PUT", - self.created_rmid_msg_path, - '{"msgtype":"m.text","body":"test msg"}') + "PUT", + self.created_rmid_msg_path, + '{"msgtype":"m.text","body":"test msg"}' + ) self.assertEquals(200, code, msg=str(response)) # set topic for public room (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - '{"topic":"Public Room Topic"}') + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + '{"topic":"Public Room Topic"}' + ) self.assertEquals(200, code, msg=str(response)) # auth as user_id now @@ -103,37 +106,6 @@ class RoomPermissionsTestCase(RestTestCase): def tearDown(self): pass -# @defer.inlineCallbacks -# def test_get_message(self): -# # get message in uncreated room, expect 403 -# (code, response) = yield self.mock_resource.trigger_get( -# "/rooms/noroom/messages/someid/m1") -# self.assertEquals(403, code, msg=str(response)) -# -# # get message in created room not joined (no state), expect 403 -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(403, code, msg=str(response)) -# -# # get message in created room and invited, expect 403 -# yield self.invite(room=self.created_rmid, src=self.rmcreator_id, -# targ=self.user_id) -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(403, code, msg=str(response)) -# -# # get message in created room and joined, expect 200 -# yield self.join(room=self.created_rmid, user=self.user_id) -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(200, code, msg=str(response)) -# -# # get message in created room and left, expect 403 -# yield self.leave(room=self.created_rmid, user=self.user_id) -# (code, response) = yield self.mock_resource.trigger_get( -# self.created_rmid_msg_path) -# self.assertEquals(403, code, msg=str(response)) - @defer.inlineCallbacks def test_send_message(self): msg_content = '{"msgtype":"m.text","body":"hello"}' @@ -195,25 +167,30 @@ class RoomPermissionsTestCase(RestTestCase): # set/get topic in uncreated room, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, - topic_content) + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, + topic_content + ) self.assertEquals(403, code, msg=str(response)) (code, response) = yield 
self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.uncreated_rmid) + "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + ) self.assertEquals(403, code, msg=str(response)) # set/get topic in created PRIVATE room not joined, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) self.assertEquals(403, code, msg=str(response)) # set topic in created PRIVATE room and invited, expect 403 - yield self.invite(room=self.created_rmid, src=self.rmcreator_id, - targ=self.user_id) + yield self.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(403, code, msg=str(response)) # get topic in created PRIVATE room and invited, expect 403 @@ -226,7 +203,8 @@ class RoomPermissionsTestCase(RestTestCase): # Only room ops can set topic by default self.auth_user_id = self.rmcreator_id (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(200, code, msg=str(response)) self.auth_user_id = self.user_id @@ -237,30 +215,31 @@ class RoomPermissionsTestCase(RestTestCase): # set/get topic in created PRIVATE room and left, expect 403 yield self.leave(room=self.created_rmid, user=self.user_id) (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content) + "PUT", topic_path, topic_content + ) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) self.assertEquals(200, code, msg=str(response)) # get topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.created_public_rmid) + "/rooms/%s/state/m.room.topic" % self.created_public_rmid + ) self.assertEquals(403, code, msg=str(response)) # set topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content) + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks def _test_get_membership(self, room=None, members=[], expect_code=None): - path = "/rooms/%s/state/m.room.member/%s" for member in members: - (code, response) = yield self.mock_resource.trigger_get( - path % - (room, member)) + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + (code, response) = yield self.mock_resource.trigger_get(path) self.assertEquals(expect_code, code) @defer.inlineCallbacks @@ -461,20 +440,23 @@ class RoomsMemberListTestCase(RestTestCase): def test_get_member_list(self): room_id = yield self.create_room_as(self.user_id) (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id) + "/rooms/%s/members" % room_id + ) self.assertEquals(200, code, msg=str(response)) @defer.inlineCallbacks def test_get_member_list_no_room(self): (code, response) = yield self.mock_resource.trigger_get( - "/rooms/roomdoesnotexist/members") + "/rooms/roomdoesnotexist/members" + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks def test_get_member_list_no_permission(self): room_id = 
yield self.create_room_as("@some_other_guy:red") (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id) + "/rooms/%s/members" % room_id + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks @@ -636,34 +618,41 @@ class RoomTopicTestCase(RestTestCase): @defer.inlineCallbacks def test_invalid_puts(self): # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '{}') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '{}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '{"_name":"bob"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '{"_name":"bob"}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '{"nao') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '{"nao' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '[{"_name":"bob"},{"_name":"jill"}]') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, 'text only') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, 'text only' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, '') + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, '' + ) self.assertEquals(400, code, msg=str(response)) # valid key, wrong type content = '{"topic":["Topic name"]}' - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, content) + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, content + ) self.assertEquals(400, code, msg=str(response)) @defer.inlineCallbacks @@ -674,8 +663,9 @@ class RoomTopicTestCase(RestTestCase): # valid put content = '{"topic":"Topic name"}' - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, content) + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, content + ) self.assertEquals(200, code, msg=str(response)) # valid get @@ -687,8 +677,9 @@ class RoomTopicTestCase(RestTestCase): def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - (code, response) = yield self.mock_resource.trigger("PUT", - self.path, content) + (code, response) = yield self.mock_resource.trigger( + "PUT", self.path, content + ) self.assertEquals(200, code, msg=str(response)) # valid get @@ -740,33 +731,38 @@ class RoomMemberStateTestCase(RestTestCase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{}') + (code, response) = yield self.mock_resource.trigger("PUT", path, '{}') self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"_name":"bob"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"_name":"bob"}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"nao') + (code, response) = yield 
self.mock_resource.trigger( + "PUT", path, '{"nao' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '[{"_name":"bob"},{"_name":"jill"}]') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, 'text only') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, 'text only' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '' + ) self.assertEquals(400, code, msg=str(response)) # valid keys, wrong types - content = ('{"membership":["%s","%s","%s"]}' % - (Membership.INVITE, Membership.JOIN, Membership.LEAVE)) + content = ('{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, Membership.JOIN, Membership.LEAVE + )) (code, response) = yield self.mock_resource.trigger("PUT", path, content) self.assertEquals(400, code, msg=str(response)) @@ -813,8 +809,9 @@ class RoomMemberStateTestCase(RestTestCase): ) # valid invite message with custom key - content = ('{"membership":"%s","invite_text":"%s"}' % - (Membership.INVITE, "Join us!")) + content = ('{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, "Join us!" + )) (code, response) = yield self.mock_resource.trigger("PUT", path, content) self.assertEquals(200, code, msg=str(response)) @@ -867,28 +864,34 @@ class RoomMessagesTestCase(RestTestCase): path = "/rooms/%s/send/m.room.message/mid1" % ( urllib.quote(self.room_id)) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{}') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"_name":"bob"}') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"_name":"bob"}' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '{"nao') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '{"nao' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '[{"_name":"bob"},{"_name":"jill"}]') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, 'text only') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, 'text only' + ) self.assertEquals(400, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger("PUT", - path, '') + (code, response) = yield self.mock_resource.trigger( + "PUT", path, '' + ) self.assertEquals(400, code, msg=str(response)) @defer.inlineCallbacks @@ -959,7 +962,8 @@ class RoomInitialSyncTestCase(RestTestCase): @defer.inlineCallbacks def test_initial_sync(self): (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/initialSync" % self.room_id) + "/rooms/%s/initialSync" % self.room_id + ) self.assertEquals(200, code) self.assertEquals(self.room_id, response["room_id"]) @@ -983,8 +987,8 @@ class RoomInitialSyncTestCase(RestTestCase): self.assertTrue("presence" in response) - presence_by_user = 
{e["content"]["user_id"]: e - for e in response["presence"] + presence_by_user = { + e["content"]["user_id"]: e for e in response["presence"] } self.assertTrue(self.user_id in presence_by_user) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index c4ac181a33..16d788ff61 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -81,9 +81,9 @@ class RoomTypingTestCase(RestTestCase): return defer.succeed([]) @defer.inlineCallbacks - def fetch_room_distributions_into(room_id, localusers=None, - remotedomains=None, ignore_user=None): - + def fetch_room_distributions_into( + room_id, localusers=None, remotedomains=None, ignore_user=None + ): members = yield get_room_members(room_id) for member in members: if ignore_user is not None and member == ignore_user: @@ -96,7 +96,8 @@ class RoomTypingTestCase(RestTestCase): if remotedomains is not None: remotedomains.add(member.domain) hs.get_handlers().room_member_handler.fetch_room_distributions_into = ( - fetch_room_distributions_into) + fetch_room_distributions_into + ) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) @@ -109,8 +110,8 @@ class RoomTypingTestCase(RestTestCase): @defer.inlineCallbacks def test_set_typing(self): - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": true, "timeout": 30000}' ) self.assertEquals(200, code) @@ -120,41 +121,38 @@ class RoomTypingTestCase(RestTestCase): from_key=0, room_ids=[self.room_id], ) - self.assertEquals( - events[0], - [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.user_id], - }}, - ] - ) + self.assertEquals(events[0], [{ + "type": "m.typing", + "room_id": self.room_id, + "content": { + "user_ids": [self.user_id], + } + }]) @defer.inlineCallbacks def test_set_not_typing(self): - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": false}' ) self.assertEquals(200, code) @defer.inlineCallbacks def test_typing_timeout(self): - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": true, "timeout": 30000}' ) self.assertEquals(200, code) self.assertEquals(self.event_source.get_current_key(), 1) - self.clock.advance_time(31); + self.clock.advance_time(31) self.assertEquals(self.event_source.get_current_key(), 2) - (code, _) = yield self.mock_resource.trigger("PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + (code, _) = yield self.mock_resource.trigger( + "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), '{"typing": true, "timeout": 30000}' ) self.assertEquals(200, code) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index af376804f6..17524b2e23 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -84,8 +84,9 @@ class RestTestCase(unittest.TestCase): "membership": membership } - (code, response) = yield self.mock_resource.trigger("PUT", path, - json.dumps(data)) + (code, response) = yield 
self.mock_resource.trigger( + "PUT", path, json.dumps(data) + ) self.assertEquals(expect_code, code, msg=str(response)) self.auth_user_id = temp_id diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 16dce6c723..84334dce34 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -55,7 +55,7 @@ class V2AlphaRestTestCase(unittest.TestCase): r.register_servlets(hs, self.mock_resource) def make_datastore_mock(self): - store = Mock(spec=[ + store = Mock(spec=[ "insert_client_ip", ]) store.get_app_service_by_token = Mock(return_value=None) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index c86e904c8e..d1442aafac 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from mock import Mock - from . import V2AlphaRestTestCase from synapse.rest.client.v2_alpha import filter @@ -53,9 +51,8 @@ class FilterTestCase(V2AlphaRestTestCase): @defer.inlineCallbacks def test_add_filter(self): - (code, response) = yield self.mock_resource.trigger("POST", - "/user/%s/filter" % (self.USER_ID), - '{"type": ["m.*"]}' + (code, response) = yield self.mock_resource.trigger( + "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}' ) self.assertEquals(200, code) self.assertEquals({"filter_id": "0"}, response) @@ -70,8 +67,8 @@ class FilterTestCase(V2AlphaRestTestCase): {"type": ["m.*"]} ] - (code, response) = yield self.mock_resource.trigger("GET", - "/user/%s/filter/0" % (self.USER_ID), None + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/0" % (self.USER_ID) ) self.assertEquals(200, code) self.assertEquals({"type": ["m.*"]}, response) @@ -82,14 +79,14 @@ class FilterTestCase(V2AlphaRestTestCase): {"type": ["m.*"]} ] - (code, response) = yield self.mock_resource.trigger("GET", - "/user/%s/filter/2" % (self.USER_ID), None + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/2" % (self.USER_ID) ) self.assertEquals(404, code) @defer.inlineCallbacks def test_get_filter_no_user(self): - (code, response) = yield self.mock_resource.trigger("GET", - "/user/%s/filter/0" % (self.USER_ID), None + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/0" % (self.USER_ID) ) self.assertEquals(404, code) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index df0841b0b1..b867599079 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,7 +1,7 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.api.errors import SynapseError from twisted.internet import defer -from mock import Mock, MagicMock +from mock import Mock from tests import unittest import json @@ -24,7 +24,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.auth_result = (False, None, None) self.auth_handler = Mock( - check_auth=Mock(side_effect=lambda x,y,z: self.auth_result) + check_auth=Mock(side_effect=lambda x, y, z: self.auth_result) ) self.registration_handler = Mock() self.identity_handler = Mock() diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index dca785eb27..f22ba8db89 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -14,15 +14,9 @@ # limitations under the License. 
-from tests import unittest from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID - -from tests.utils import setup_test_homeserver - -from mock import Mock +from synapse.api.constants import EventTypes class EventInjector: diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 219288621d..96b7dba5fe 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -174,11 +174,13 @@ class CacheDecoratorTestCase(unittest.TestCase): # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times - for k in range(0,12): + for k in range(0, 12): yield a.func(k) - self.assertTrue(callcount[0] >= 14, - msg="Expected callcount >= 14, got %d" % (callcount[0])) + self.assertTrue( + callcount[0] >= 14, + msg="Expected callcount >= 14, got %d" % (callcount[0]) + ) def test_prefill(self): callcount = [0] diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 29289fa9b4..6e4d9b1373 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,13 +1,11 @@ from tests import unittest from twisted.internet import defer -from synapse.api.constants import EventTypes -from synapse.types import UserID, RoomID, RoomAlias - from tests.utils import setup_test_homeserver from mock import Mock + class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -24,8 +22,8 @@ class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000; - duration_ms = 42; + desired_count = 1000 + duration_ms = 42 @defer.inlineCallbacks def update(progress, count): diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 152d027663..c76545be65 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -17,7 +17,7 @@ from tests import unittest from twisted.internet import defer -from mock import Mock, call +from mock import Mock from collections import OrderedDict @@ -62,13 +62,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_insert( - table="tablename", - values={"columname": "Value"} + table="tablename", + values={"columname": "Value"} ) self.mock_txn.execute.assert_called_with( - "INSERT INTO tablename (columname) VALUES(?)", - ("Value",) + "INSERT INTO tablename (columname) VALUES(?)", ("Value",) ) @defer.inlineCallbacks @@ -76,14 +75,14 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]) + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]) ) self.mock_txn.execute.assert_called_with( - "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", - (1, 2, 3,) + "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", + (1, 2, 3,) ) @defer.inlineCallbacks @@ -92,15 +91,14 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.fetchall.return_value = [("Value",)] value = yield self.datastore._simple_select_one_onecol( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcol="retcol" + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcol="retcol" ) self.assertEquals("Value", 
value) self.mock_txn.execute.assert_called_with( - "SELECT retcol FROM tablename WHERE keycol = ?", - ["TheKey"] + "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @defer.inlineCallbacks @@ -109,15 +107,15 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.fetchone.return_value = (1, 2, 3) ret = yield self.datastore._simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"] + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"] ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( - "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", - ["TheKey"] + "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", + ["TheKey"] ) @defer.inlineCallbacks @@ -126,32 +124,32 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.fetchone.return_value = None ret = yield self.datastore._simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True ) self.assertFalse(ret) @defer.inlineCallbacks def test_select_list(self): - self.mock_txn.rowcount = 3; + self.mock_txn.rowcount = 3 self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) self.mock_txn.description = ( - ("colA", None, None, None, None, None, None), + ("colA", None, None, None, None, None, None), ) ret = yield self.datastore._simple_select_list( - table="tablename", - keyvalues={"keycol": "A set"}, - retcols=["colA"], + table="tablename", + keyvalues={"keycol": "A set"}, + retcols=["colA"], ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( - "SELECT colA FROM tablename WHERE keycol = ?", - ["A set"] + "SELECT colA FROM tablename WHERE keycol = ?", + ["A set"] ) @defer.inlineCallbacks @@ -159,14 +157,14 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"} + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"} ) self.mock_txn.execute.assert_called_with( - "UPDATE tablename SET columnname = ? WHERE keycol = ?", - ["New Value", "TheKey"] + "UPDATE tablename SET columnname = ? WHERE keycol = ?", + ["New Value", "TheKey"] ) @defer.inlineCallbacks @@ -174,15 +172,15 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]) + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]) ) self.mock_txn.execute.assert_called_with( - "UPDATE tablename SET colC = ?, colD = ? WHERE " + - "colA = ? AND colB = ?", - [3, 4, 1, 2] + "UPDATE tablename SET colC = ?, colD = ? WHERE" + " colA = ? 
AND colB = ?", + [3, 4, 1, 2] ) @defer.inlineCallbacks @@ -190,11 +188,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_delete_one( - table="tablename", - keyvalues={"keycol": "Go away"}, + table="tablename", + keyvalues={"keycol": "Go away"}, ) self.mock_txn.execute.assert_called_with( - "DELETE FROM tablename WHERE keycol = ?", - ["Go away"] + "DELETE FROM tablename WHERE keycol = ?", ["Go away"] ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 4aa82d4c9d..18a6cff0c7 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import uuid from mock import Mock from synapse.types import RoomID, UserID diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 47e2768b2c..24118bbc86 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -55,7 +55,7 @@ class ProfileStoreTestCase(unittest.TestCase): ) yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.localpart, "http://my.site/here" ) self.assertEquals( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7b3b4c13bc..b8384c98d8 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -33,8 +33,10 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.user_id = "@my-user:test" - self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", - "BcDeFgHiJkLmNoPqRsTuVwXyZa"] + self.tokens = [ + "AbCdEfGhIjKlMnOpQrStUvWxYz", + "BcDeFgHiJkLmNoPqRsTuVwXyZa" + ] self.pwhash = "{xx1}123456789" @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 0baaf3df21..ef8a4d234f 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,7 +37,8 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room(self.room.to_string(), + yield self.store.store_room( + self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), is_public=True ) @@ -45,9 +46,11 @@ class RoomStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_room(self): self.assertDictContainsSubset( - {"room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "is_public": True}, + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "is_public": True + }, (yield self.store.get_room(self.room.to_string())) ) @@ -65,7 +68,8 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room(self.room.to_string(), + yield self.store.store_room( + self.room.to_string(), room_creator_user_id="@creator:text", is_public=True ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index bab15c4165..677d11f68d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -88,8 +88,8 @@ class RoomMemberStoreTestCase(unittest.TestCase): [m.room_id for m in ( yield self.store.get_rooms_for_user_where_membership_is( self.u_alice.to_string(), [Membership.JOIN] - )) - ] + ) + )] ) self.assertFalse( (yield self.store.user_rooms_intersect( @@ -108,11 
+108,11 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.store.get_room_members(self.room.to_string()) )} ) - self.assertTrue( - (yield self.store.user_rooms_intersect( - [self.u_alice.to_string(), self.u_bob.to_string()] - )) - ) + self.assertTrue(( + yield self.store.user_rooms_intersect([ + self.u_alice.to_string(), self.u_bob.to_string() + ]) + )) @defer.inlineCallbacks def test_room_hosts(self): @@ -136,9 +136,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.assertEquals( {"test", "elsewhere"}, - (yield - self.store.get_joined_hosts_for_room(self.room.to_string()) - ) + (yield self.store.get_joined_hosts_for_room(self.room.to_string())) ) # Should still have both hosts @@ -146,9 +144,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.assertEquals( {"test", "elsewhere"}, - (yield - self.store.get_joined_hosts_for_room(self.room.to_string()) - ) + (yield self.store.get_joined_hosts_for_room(self.room.to_string())) ) # Should have only one host after other leaves diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 708208aff1..da322152c7 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -156,13 +156,13 @@ class StreamStoreTestCase(unittest.TestCase): self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.event_injector.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.event_injector.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, ) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index a80f580ba6..acebcf4a86 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -44,8 +44,10 @@ class DistributorTestCase(unittest.TestCase): self.dist.declare("whine") d_inner = defer.Deferred() + def observer(): return d_inner + self.dist.observe("whine", observer) d_outer = self.dist.fire("whine") @@ -66,8 +68,8 @@ class DistributorTestCase(unittest.TestCase): observers[0].side_effect = Exception("Awoogah!") - with patch("synapse.util.distributor.logger", - spec=["warning"] + with patch( + "synapse.util.distributor.logger", spec=["warning"] ) as mock_logger: d = self.dist.fire("alarm", "Go") yield d @@ -77,8 +79,9 @@ class DistributorTestCase(unittest.TestCase): observers[1].assert_called_once_with("Go") self.assertEquals(mock_logger.warning.call_count, 1) - self.assertIsInstance(mock_logger.warning.call_args[0][0], - str) + self.assertIsInstance( + mock_logger.warning.call_args[0][0], str + ) @defer.inlineCallbacks def test_signal_catch_no_suppress(self): diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 3718f4fc3f..d28bb726bb 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -17,6 +17,7 @@ from tests import unittest from tests.utils import MockClock + class MockClockTestCase(unittest.TestCase): def setUp(self): @@ -60,7 +61,7 @@ class MockClockTestCase(unittest.TestCase): def _cb1(): invoked[1] = 1 - t1 = self.clock.call_later(20, _cb1) + self.clock.call_later(20, _cb1) self.clock.cancel_call_later(t0) diff --git a/tests/unittest.py b/tests/unittest.py index 6f02eb4cac..5b22abfe74 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -37,9 +37,12 @@ def around(target): def _around(code): name = code.__name__ orig = getattr(target, name) + def new(*args, **kwargs): return code(orig, *args, **kwargs) + setattr(target, name, new) + 
return _around @@ -53,9 +56,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", - getattr(self, "loglevel", - NEVER)) + level = getattr(method, "loglevel", getattr(self, "loglevel", NEVER)) @around(self) def setUp(orig): diff --git a/tests/util/__init__.py b/tests/util/__init__.py index d0e9399dda..bfebb0f644 100644 --- a/tests/util/__init__.py +++ b/tests/util/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 7bbe795622..272b71034a 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -14,7 +14,6 @@ # limitations under the License. -from twisted.internet import defer from tests import unittest from synapse.util.caches.dictionary_cache import DictionaryCache diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index 4ee0d49673..7e289715ba 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -19,6 +19,7 @@ from .. import unittest from synapse.util.caches.snapshot_cache import SnapshotCache from twisted.internet.defer import Deferred + class SnapshotCacheTestCase(unittest.TestCase): def setUp(self): diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 1efbeb6b33..bcc64422b2 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -18,6 +18,7 @@ from .. import unittest from synapse.util.caches.treecache import TreeCache + class TreeCacheTestCase(unittest.TestCase): def test_get_set_onelevel(self): cache = TreeCache() diff --git a/tests/utils.py b/tests/utils.py index f71125042b..bf7a31ff9e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -152,7 +152,7 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" - mock_request.requestHeaders.getRawHeaders.return_value=[ + mock_request.requestHeaders.getRawHeaders.return_value = [ "X-Matrix origin=test,key=,sig=" ] @@ -360,13 +360,12 @@ class MemoryDataStore(object): def get_rooms_for_user_where_membership_is(self, user_id, membership_list): return [ - self.members[r].get(user_id) for r in self.members - if user_id in self.members[r] and - self.members[r][user_id].membership in membership_list + m[user_id] for m in self.members.values() + if user_id in m and m[user_id].membership in membership_list ] def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - limit=0, with_feedback=False): + limit=0, with_feedback=False): return ([], from_key) # TODO def get_joined_hosts_for_room(self, room_id): @@ -376,7 +375,6 @@ class MemoryDataStore(object): if event.type == EventTypes.Member: room_id = event.room_id user = event.state_key - membership = event.membership self.members.setdefault(room_id, {})[user] = event if hasattr(event, "state_key"): @@ -456,9 +454,9 @@ class DeferredMockCallable(object): d.callback(None) return result - failure = AssertionError("Was not expecting call(%s)" % + failure = AssertionError("Was not expecting call(%s)" % ( _format_call(args, kwargs) - ) + )) for _, _, d in self.expectations: try: @@ -479,14 +477,12 @@ class DeferredMockCallable(object): ) timer = reactor.callLater( - timeout/1000, + timeout / 1000, deferred.errback, - AssertionError( - "%d pending calls left: %s"% ( - len([e for e in self.expectations if not e[2].called]), - [e for 
e in self.expectations if not e[2].called] - ) - ) + AssertionError("%d pending calls left: %s" % ( + len([e for e in self.expectations if not e[2].called]), + [e for e in self.expectations if not e[2].called] + )) ) yield deferred @@ -500,8 +496,8 @@ class DeferredMockCallable(object): calls = self.calls self.calls = [] - raise AssertionError("Expected not to received any calls, got:\n" + - "\n".join([ + raise AssertionError( + "Expected not to received any calls, got:\n" + "\n".join([ "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) diff --git a/tox.ini b/tox.ini index 6f01599af7..757b7189c3 100644 --- a/tox.ini +++ b/tox.ini @@ -26,4 +26,4 @@ skip_install = True basepython = python2.7 deps = flake8 -commands = /bin/bash -c "flake8 synapse {env:PEP8SUFFIX:}" +commands = /bin/bash -c "flake8 synapse tests {env:PEP8SUFFIX:}" -- cgit 1.4.1 From 60a0f81c7a2da86bf959227a440e3f7a2b727bb5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Mar 2016 14:49:41 +0000 Subject: Add a /replication API for extracting the updates that happened on synapse This is necessary for replicating the data in synapse to be visible to a separate service because presence and typing notifications aren't stored in a database so won't be visible to another process. This API can be used to either get the raw data by requesting the tables themselves or to just receive notifications for updates by following the streams meta-stream. Returns updates for each table requested a JSON array of arrays with a row for each row in the table. Each table is prefixed by a header row with the: name of the table, current stream_id position for the table, number of rows, number of columns and the names of the columns. This is followed by the rows that have been added to the server since the requester last asked. The API has a timeout and is hooked up to the notifier so that a slave can long poll for updates. 
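For illustration, a minimal consumer of this API might look like the sketch below (not part of the patch itself). It follows a single stream by long-polling the endpoint and re-sending the last position it saw; the server URL is a placeholder, and the endpoint path and per-stream response shape ("position", "field_names", "rows") mirror the ReplicationResource and scripts-dev/tail-synapse.py added below.

    # Sketch only: follow one replication stream from a separate process.
    # SERVER is a placeholder; scripts-dev/tail-synapse.py (added below)
    # is the fuller version that tracks every stream at once.
    import requests

    SERVER = "https://localhost:8448"

    def follow_stream(name, position="-1"):
        while True:
            response = requests.get(
                SERVER + "/_synapse/replication",
                params={name: position, "timeout": 30000},
                verify=False,  # self-signed cert on a local homeserver
            )
            stream = response.json().get(name)
            if stream is None:
                continue  # request timed out with no new rows; poll again
            for row in stream["rows"]:
                # Each row is an array; pair it with the advertised field names.
                print(dict(zip(stream["field_names"], row)))
            position = stream["position"]

    follow_stream("events")
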
--- scripts-dev/tail-synapse.py | 67 ++++++++ synapse/app/homeserver.py | 4 + synapse/handlers/presence.py | 19 +++ synapse/handlers/typing.py | 14 ++ synapse/notifier.py | 48 ++++++ synapse/replication/__init__.py | 14 ++ synapse/replication/resource.py | 320 +++++++++++++++++++++++++++++++++++++ synapse/storage/account_data.py | 36 ++++- synapse/storage/events.py | 45 ++++++ synapse/storage/presence.py | 16 ++ synapse/storage/receipts.py | 16 ++ synapse/storage/tags.py | 53 ++++++ tests/replication/__init__.py | 14 ++ tests/replication/test_resource.py | 179 +++++++++++++++++++++ tests/utils.py | 5 +- 15 files changed, 846 insertions(+), 4 deletions(-) create mode 100644 scripts-dev/tail-synapse.py create mode 100644 synapse/replication/__init__.py create mode 100644 synapse/replication/resource.py create mode 100644 tests/replication/__init__.py create mode 100644 tests/replication/test_resource.py (limited to 'tests/utils.py') diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py new file mode 100644 index 0000000000..18be711e92 --- /dev/null +++ b/scripts-dev/tail-synapse.py @@ -0,0 +1,67 @@ +import requests +import collections +import sys +import time +import json + +Entry = collections.namedtuple("Entry", "name position rows") + +ROW_TYPES = {} + + +def row_type_for_columns(name, column_names): + column_names = tuple(column_names) + row_type = ROW_TYPES.get((name, column_names)) + if row_type is None: + row_type = collections.namedtuple(name, column_names) + ROW_TYPES[(name, column_names)] = row_type + return row_type + + +def parse_response(content): + streams = json.loads(content) + result = {} + for name, value in streams.items(): + row_type = row_type_for_columns(name, value["field_names"]) + position = value["position"] + rows = [row_type(*row) for row in value["rows"]] + result[name] = Entry(name, position, rows) + return result + + +def replicate(server, streams): + return parse_response(requests.get( + server + "/_synapse/replication", + verify=False, + params=streams + ).content) + + +def main(): + server = sys.argv[1] + + streams = None + while not streams: + try: + streams = { + row.name: row.position + for row in replicate(server, {"streams":"-1"})["streams"].rows + } + except requests.exceptions.ConnectionError as e: + time.sleep(0.1) + + while True: + try: + results = replicate(server, streams) + except: + sys.stdout.write("connection_lost("+ repr(streams) + ")\n") + break + for update in results.values(): + for row in update.rows: + sys.stdout.write(repr(row) + "\n") + streams[update.name] = update.position + + + +if __name__=='__main__': + main() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b4be7bdd0..de5ee988f1 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer from synapse import events @@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer): if name == "metrics" and self.get_config().enable_metrics: resources[METRICS_PREFIX] = MetricsResource(self) + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationResource(self) + root_resource = create_resource_tree(resources) if tls: reactor.listenSSL( diff 
--git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 08e38cdd25..d98e80086e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -774,6 +774,25 @@ class PresenceHandler(BaseHandler): defer.returnValue(observer_user.to_string() in accepted_observers) + @defer.inlineCallbacks + def get_all_presence_updates(self, last_id, current_id): + """ + Gets a list of presence update rows from between the given stream ids. + Each row has: + - stream_id(str) + - user_id(str) + - state(str) + - last_active_ts(int) + - last_federation_update_ts(int) + - last_user_sync_ts(int) + - status_msg(int) + - currently_active(int) + """ + # TODO(markjh): replicate the unpersisted changes. + # This could use the in-memory stores for recent changes. + rows = yield self.store.get_all_presence_updates(last_id, current_id) + defer.returnValue(rows) + def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b16d0017df..8ce27f49ec 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -25,6 +25,7 @@ from synapse.types import UserID import logging from collections import namedtuple +import ujson as json logger = logging.getLogger(__name__) @@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler): "typing_key", self._latest_room_serial, rooms=[room_id] ) + def get_all_typing_updates(self, last_id, current_id): + # TODO: Work out a way to do this without scanning the entire state. + rows = [] + for room_id, serial in self._room_serials.items(): + if last_id < serial and serial <= current_id: + typing = self._room_typing[room_id] + typing_bytes = json.dumps([ + u.to_string() for u in typing + ], ensure_ascii=False) + rows.append((serial, room_id, typing_bytes)) + rows.sort() + return rows + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/notifier.py b/synapse/notifier.py index 560866b26e..3c36a20868 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,8 @@ class Notifier(object): self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS ) + self.replication_deferred = ObservableDeferred(defer.Deferred()) + # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. @@ -207,6 +209,8 @@ class Notifier(object): )) self._notify_pending_new_room_events(max_room_stream_id) + self.notify_replication() + def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. 
@@ -276,6 +280,8 @@ class Notifier(object): except: logger.exception("Failed to notify listener") + self.notify_replication() + @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): @@ -479,3 +485,45 @@ class Notifier(object): room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) + + def notify_replication(self): + """Notify the any replication listeners that there's a new event""" + with PreserveLoggingContext(): + deferred = self.replication_deferred + self.replication_deferred = ObservableDeferred(defer.Deferred()) + deferred.callback(None) + + @defer.inlineCallbacks + def wait_for_replication(self, callback, timeout): + """Wait for an event to happen. + + :param callback: + Gets called whenever an event happens. If this returns a truthy + value then ``wait_for_replication`` returns, otherwise it waits + for another event. + :param int timeout: + How many milliseconds to wait for callback return a truthy value. + :returns: + A deferred that resolves with the value returned by the callback. + """ + listener = _NotificationListener(None) + + def timed_out(): + listener.deferred.cancel() + + timer = self.clock.call_later(timeout / 1000., timed_out) + while True: + listener.deferred = self.replication_deferred.observe() + result = yield callback() + if result: + break + + try: + with PreserveLoggingContext(): + yield listener.deferred + except defer.CancelledError: + break + + self.clock.cancel_call_later(timer, ignore_errs=True) + + defer.returnValue(result) diff --git a/synapse/replication/__init__.py b/synapse/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py new file mode 100644 index 0000000000..e0d039518d --- /dev/null +++ b/synapse/replication/resource.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.http.servlet import parse_integer, parse_string +from synapse.http.server import request_handler, finish_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + +import ujson as json + +import collections +import logging + +logger = logging.getLogger(__name__) + +REPLICATION_PREFIX = "/_synapse/replication" + +STREAM_NAMES = ( + ("events",), + ("presence",), + ("typing",), + ("receipts",), + ("user_account_data", "room_account_data", "tag_account_data",), + ("backfill",), +) + + +class ReplicationResource(Resource): + """ + HTTP endpoint for extracting data from synapse. + + The streams of data returned by the endpoint are controlled by the + parameters given to the API. To return a given stream pass a query + parameter with a position in the stream to return data from or the + special value "-1" to return data from the start of the stream. + + If there is no data for any of the supplied streams after the given + position then the request will block until there is data for one + of the streams. This allows clients to long-poll this API. + + The possible streams are: + + * "streams": A special stream returing the positions of other streams. + * "events": The new events seen on the server. + * "presence": Presence updates. + * "typing": Typing updates. + * "receipts": Receipt updates. + * "user_account_data": Top-level per user account data. + * "room_account_data: Per room per user account data. + * "tag_account_data": Per room per user tags. + * "backfill": Old events that have been backfilled from other servers. + + The API takes two additional query parameters: + + * "timeout": How long to wait before returning an empty response. + * "limit": The maximum number of rows to return for the selected streams. + + The response is a JSON object with keys for each stream with updates. Under + each key is a JSON object with: + + * "postion": The current position of the stream. + * "field_names": The names of the fields in each row. + * "rows": The updates as an array of arrays. + + There are a number of ways this API could be used: + + 1) To replicate the contents of the backing database to another database. + 2) To be notified when the contents of a shared backing database changes. + 3) To "tail" the activity happening on a server for debugging. + + In the first case the client would track all of the streams and store it's + own copy of the data. + + In the second case the client might theoretically just be able to follow + the "streams" stream to track where the other streams are. However in + practise it will probably need to get the contents of the streams in + order to expire the any in-memory caches. Whether it gets the contents + of the streams from this replication API or directly from the backing + store is a matter of taste. + + In the third case the client would use the "streams" stream to find what + streams are available and their current positions. Then it can start + long-polling this replication API for new data on those streams. 
+ """ + + isLeaf = True + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.store = hs.get_datastore() + self.sources = hs.get_event_sources() + self.presence_handler = hs.get_handlers().presence_handler + self.typing_handler = hs.get_handlers().typing_notification_handler + self.notifier = hs.notifier + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @defer.inlineCallbacks + def current_replication_token(self): + stream_token = yield self.sources.get_current_token() + backfill_token = yield self.store.get_current_backfill_token() + + defer.returnValue(_ReplicationToken( + stream_token.room_stream_id, + int(stream_token.presence_key), + int(stream_token.typing_key), + int(stream_token.receipt_key), + int(stream_token.account_data_key), + backfill_token, + )) + + @request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + limit = parse_integer(request, "limit", 100) + timeout = parse_integer(request, "timeout", 10 * 1000) + + request.setHeader(b"Content-Type", b"application/json") + writer = _Writer(request) + + @defer.inlineCallbacks + def replicate(): + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) + + yield self.account_data(writer, current_token, limit) + yield self.events(writer, current_token, limit) + yield self.presence(writer, current_token) # TODO: implement limit + yield self.typing(writer, current_token) # TODO: implement limit + yield self.receipts(writer, current_token, limit) + self.streams(writer, current_token) + + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.total) + + yield self.notifier.wait_for_replication(replicate, timeout) + + writer.finish() + + def streams(self, writer, current_token): + request_token = parse_string(writer.request, "streams") + + streams = [] + + if request_token is not None: + if request_token == "-1": + for names, position in zip(STREAM_NAMES, current_token): + streams.extend((name, position) for name in names) + else: + items = zip( + STREAM_NAMES, + current_token, + _ReplicationToken(request_token) + ) + for names, current_id, last_id in items: + if last_id < current_id: + streams.extend((name, current_id) for name in names) + + if streams: + writer.write_header_and_rows( + "streams", streams, ("name", "position"), + position=str(current_token) + ) + + @defer.inlineCallbacks + def events(self, writer, current_token, limit): + request_events = parse_integer(writer.request, "events") + request_backfill = parse_integer(writer.request, "backfill") + + if request_events is not None or request_backfill is not None: + if request_events is None: + request_events = current_token.events + if request_backfill is None: + request_backfill = current_token.backfill + events_rows, backfill_rows = yield self.store.get_all_new_events( + request_backfill, request_events, + current_token.backfill, current_token.events, + limit + ) + writer.write_header_and_rows( + "events", events_rows, ("position", "internal", "json") + ) + writer.write_header_and_rows( + "backfill", backfill_rows, ("position", "internal", "json") + ) + + @defer.inlineCallbacks + def presence(self, writer, current_token): + current_position = current_token.presence + + request_presence = parse_integer(writer.request, "presence") + + if request_presence is not None: + presence_rows = yield self.presence_handler.get_all_presence_updates( + 
request_presence, current_position + ) + writer.write_header_and_rows("presence", presence_rows, ( + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + )) + + @defer.inlineCallbacks + def typing(self, writer, current_token): + current_position = current_token.presence + + request_typing = parse_integer(writer.request, "typing") + + if request_typing is not None: + typing_rows = yield self.typing_handler.get_all_typing_updates( + request_typing, current_position + ) + writer.write_header_and_rows("typing", typing_rows, ( + "position", "room_id", "typing" + )) + + @defer.inlineCallbacks + def receipts(self, writer, current_token, limit): + current_position = current_token.receipts + + request_receipts = parse_integer(writer.request, "receipts") + + if request_receipts is not None: + receipts_rows = yield self.store.get_all_updated_receipts( + request_receipts, current_position, limit + ) + writer.write_header_and_rows("receipts", receipts_rows, ( + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + )) + + @defer.inlineCallbacks + def account_data(self, writer, current_token, limit): + current_position = current_token.account_data + + user_account_data = parse_integer(writer.request, "user_account_data") + room_account_data = parse_integer(writer.request, "room_account_data") + tag_account_data = parse_integer(writer.request, "tag_account_data") + + if user_account_data is not None or room_account_data is not None: + if user_account_data is None: + user_account_data = current_position + if room_account_data is None: + room_account_data = current_position + user_rows, room_rows = yield self.store.get_all_updated_account_data( + user_account_data, room_account_data, current_position, limit + ) + writer.write_header_and_rows("user_account_data", user_rows, ( + "position", "user_id", "type", "content" + )) + writer.write_header_and_rows("room_account_data", room_rows, ( + "position", "user_id", "room_id", "type", "content" + )) + + if tag_account_data is not None: + tag_rows = yield self.store.get_all_updated_tags( + tag_account_data, current_position, limit + ) + writer.write_header_and_rows("tag_account_data", tag_rows, ( + "position", "user_id", "room_id", "tags" + )) + + +class _Writer(object): + """Writes the streams as a JSON object as the response to the request""" + def __init__(self, request): + self.streams = {} + self.request = request + self.total = 0 + + def write_header_and_rows(self, name, rows, fields, position=None): + if not rows: + return + + if position is None: + position = rows[-1][0] + + self.streams[name] = { + "position": str(position), + "field_names": fields, + "rows": rows, + } + + self.total += len(rows) + + def finish(self): + self.request.write(json.dumps(self.streams, ensure_ascii=False)) + finish_request(self.request) + + +class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( + "events", "presence", "typing", "receipts", "account_data", "backfill", +))): + __slots__ = [] + + def __new__(cls, *args): + if len(args) == 1: + return cls(*(int(value) for value in args[0].split("_"))) + else: + return super(_ReplicationToken, cls).__new__(cls, *args) + + def __str__(self): + return "_".join(str(value) for value in self) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 91cbf399b6..05f98a9a29 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -83,8 +83,40 @@ class 
AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): - """Get all the client account_data for a that's changed. + def get_all_updated_account_data(self, last_global_id, last_room_id, + current_id, limit): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return (global_results, room_results) + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1dd3236829..c0872dd7e2 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore): yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) defer.returnValue(result) + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + + # TODO: Fix race with the persit_event txn by using one of the + # stream id managers + return -self.min_stream_token + + def get_all_new_events(self, last_backfill_id, last_forward_id, + current_backfill_id, current_forward_id, limit): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + if last_forward_id != current_forward_id: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + else: + new_forward_events = [] + + sql = ( + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" + " ORDER BY e.stream_ordering DESC" + " LIMIT ?" 
+ ) + if last_backfill_id != current_backfill_id: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + else: + new_backfill_events = [] + + return (new_forward_events, new_backfill_events) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3ef91d34db..de15741893 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -115,6 +115,22 @@ class PresenceStore(SQLBaseStore): args ) + def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + @defer.inlineCallbacks def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a7343c97f7..6567fa844f 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) + + def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 9551aa9739..b225f508a5 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -58,6 +58,59 @@ class TagsStore(SQLBaseStore): return deferred + @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" + " FROM room_tags" + " WHERE user_id=? AND room_id=?" 
+ ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn.fetchall(): + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in xrange(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i:i + batch_size], + ) + results.extend(tags) + + defer.returnValue(results) + @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): """Get all the tags for the rooms where the tags have changed since the diff --git a/tests/replication/__init__.py b/tests/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py new file mode 100644 index 0000000000..38daaf87e2 --- /dev/null +++ b/tests/replication/test_resource.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.replication.resource import ReplicationResource +from synapse.types import Requester, UserID + +from twisted.internet import defer +from tests import unittest +from tests.utils import setup_test_homeserver +from mock import Mock, NonCallableMock +import json +import contextlib + + +class ReplicationResourceCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "red", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.user = UserID.from_string("@seeing:red") + + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.resource = ReplicationResource(self.hs) + + @defer.inlineCallbacks + def test_streams(self): + # Passing "-1" returns the current stream positions + code, body = yield self.get(streams="-1") + self.assertEquals(code, 200) + self.assertEquals(body["streams"]["field_names"], ["name", "position"]) + position = body["streams"]["position"] + # Passing the current position returns an empty response after the + # timeout + get = self.get(streams=str(position), timeout="0") + self.hs.clock.advance_time_msec(1) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body, {}) + + @defer.inlineCallbacks + def test_events(self): + get = self.get(events="-1", timeout="0") + yield self.hs.get_handlers().room_creation_handler.create_room( + Requester(self.user, "", False), {} + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["events"]["field_names"], [ + "position", "internal", "json" + ]) + + @defer.inlineCallbacks + def test_presence(self): + get = self.get(presence="-1") + yield self.hs.get_handlers().presence_handler.set_state( + self.user, {"presence": "online"} + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["presence"]["field_names"], [ + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + ]) + + @defer.inlineCallbacks + def test_typing(self): + room_id = yield self.create_room() + get = self.get(typing="-1") + yield self.hs.get_handlers().typing_notification_handler.started_typing( + self.user, self.user, room_id, timeout=2 + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["typing"]["field_names"], [ + "position", "room_id", "typing" + ]) + + @defer.inlineCallbacks + def test_receipts(self): + room_id = yield self.create_room() + event_id = yield self.send_text_message(room_id, "Hello, World") + get = self.get(receipts="-1") + yield self.hs.get_handlers().receipts_handler.received_client_receipt( + room_id, "m.read", self.user.to_string(), event_id + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["receipts"]["field_names"], [ + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + ]) + + def _test_timeout(stream): + """Check that a request for the given stream timesout""" + @defer.inlineCallbacks + def test_timeout(self): + get = self.get(**{stream: "-1", "timeout": "0"}) + self.hs.clock.advance_time_msec(1) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body, {}) + test_timeout.__name__ = "test_timeout_%s" % (stream) + return test_timeout + + test_timeout_events = _test_timeout("events") + test_timeout_presence = _test_timeout("presence") + test_timeout_typing = _test_timeout("typing") + test_timeout_receipts = 
_test_timeout("receipts") + test_timeout_user_account_data = _test_timeout("user_account_data") + test_timeout_room_account_data = _test_timeout("room_account_data") + test_timeout_tag_account_data = _test_timeout("tag_account_data") + test_timeout_backfill = _test_timeout("backfill") + + @defer.inlineCallbacks + def send_text_message(self, room_id, message): + handler = self.hs.get_handlers().message_handler + event = yield handler.create_and_send_nonmember_event({ + "type": "m.room.message", + "content": {"body": "message", "msgtype": "m.text"}, + "room_id": room_id, + "sender": self.user.to_string(), + }) + defer.returnValue(event.event_id) + + @defer.inlineCallbacks + def create_room(self): + result = yield self.hs.get_handlers().room_creation_handler.create_room( + Requester(self.user, "", False), {} + ) + defer.returnValue(result["room_id"]) + + @defer.inlineCallbacks + def get(self, **params): + request = NonCallableMock(spec_set=[ + "write", "finish", "setResponseCode", "setHeader", "args", + "method", "processing" + ]) + + request.method = "GET" + request.args = {k: [v] for k, v in params.items()} + + @contextlib.contextmanager + def processing(): + yield + request.processing = processing + + yield self.resource._async_render_GET(request) + self.assertTrue(request.finish.called) + + if request.setResponseCode.called: + response_code = request.setResponseCode.call_args[0][0] + else: + response_code = 200 + + response_json = "".join( + call[0][0] for call in request.write.call_args_list + ) + response_body = json.loads(response_json) + + defer.returnValue((response_code, response_body)) diff --git a/tests/utils.py b/tests/utils.py index bf7a31ff9e..dfbee5c23a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -239,9 +239,10 @@ class MockClock(object): def looping_call(self, function, interval): pass - def cancel_call_later(self, timer): + def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: - raise Exception("Cannot cancel an expired timer") + if not ignore_errs: + raise Exception("Cannot cancel an expired timer") timer[2] = True self.timers = [t for t in self.timers if t != timer] -- cgit 1.4.1 From b4022cc487921ec46942a6a72fb174bb7aa1e459 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 3 Mar 2016 16:43:42 +0000 Subject: Pass whole requester to ratelimiting This will enable more detailed decisions --- synapse/handlers/_base.py | 15 +++++-- synapse/handlers/directory.py | 20 ++++++---- synapse/handlers/federation.py | 4 +- synapse/handlers/message.py | 8 ++-- synapse/handlers/profile.py | 17 ++++---- synapse/handlers/room.py | 76 +++++++++++++++++++++--------------- synapse/rest/client/v1/directory.py | 6 ++- synapse/rest/client/v1/profile.py | 4 +- synapse/rest/client/v1/room.py | 8 ++-- tests/handlers/test_profile.py | 16 ++++++-- tests/replication/test_resource.py | 17 ++++---- tests/rest/client/v1/test_profile.py | 4 +- tests/utils.py | 5 +++ 13 files changed, 124 insertions(+), 76 deletions(-) (limited to 'tests/utils.py') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index bdade98bf7..2333fc0c09 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -160,10 +160,10 @@ class BaseHandler(object): ) defer.returnValue(res.get(user_id, [])) - def ratelimit(self, user_id): + def ratelimit(self, requester): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( - user_id, time_now, + requester.user.to_string(), time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, 
burst_count=self.hs.config.rc_message_burst_count, ) @@ -263,11 +263,18 @@ class BaseHandler(object): return False @defer.inlineCallbacks - def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): + def handle_new_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[] + ): # We now need to go and hit out to wherever we need to hit out to. if ratelimit: - self.ratelimit(event.sender) + self.ratelimit(requester) self.auth.check(event, auth_events=context.current_state) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index e0a778e7ff..88166f0187 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -212,17 +212,21 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def send_room_alias_update_event(self, user_id, room_id): + def send_room_alias_update_event(self, requester, user_id, room_id): aliases = yield self.store.get_aliases_for_room(room_id) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Aliases, - "state_key": self.hs.hostname, - "room_id": room_id, - "sender": user_id, - "content": {"aliases": aliases}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Aliases, + "state_key": self.hs.hostname, + "room_id": room_id, + "sender": user_id, + "content": {"aliases": aliases}, + }, + ratelimit=False + ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3655b9e5e2..6e50b0963e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1657,7 +1657,7 @@ class FederationHandler(BaseHandler): self.auth.check(event, context.current_state) yield self._check_signature(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, from_client=False) + yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.replication_layer.forward_third_party_invite( @@ -1686,7 +1686,7 @@ class FederationHandler(BaseHandler): # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event(event, context, from_client=False) + yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index afa7c9c36c..cace1cb82a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -215,7 +215,7 @@ class MessageHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def send_nonmember_event(self, event, context, ratelimit=True): + def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. 
@@ -241,6 +241,7 @@ class MessageHandler(BaseHandler): defer.returnValue(prev_state) yield self.handle_new_client_event( + requester=requester, event=event, context=context, ratelimit=ratelimit, @@ -268,9 +269,9 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_nonmember_event( self, + requester, event_dict, ratelimit=True, - token_id=None, txn_id=None ): """ @@ -280,10 +281,11 @@ class MessageHandler(BaseHandler): """ event, context = yield self.create_event( event_dict, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id ) yield self.send_nonmember_event( + requester, event, context, ratelimit=ratelimit, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c9ad5944e6..b45eafbb49 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -89,13 +89,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, auth_user, new_displayname): + def set_displayname(self, target_user, requester, new_displayname): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -109,7 +109,7 @@ class ProfileHandler(BaseHandler): "displayname": new_displayname, }) - yield self._update_join_states(target_user) + yield self._update_join_states(requester) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -139,13 +139,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, auth_user, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( @@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler): "avatar_url": new_avatar_url, }) - yield self._update_join_states(target_user) + yield self._update_join_states(requester) @defer.inlineCallbacks def collect_presencelike_data(self, user, state): @@ -199,11 +199,12 @@ class ProfileHandler(BaseHandler): defer.returnValue(response) @defer.inlineCallbacks - def _update_join_states(self, user): + def _update_join_states(self, requester): + user = requester.user if not self.hs.is_mine(user): return - self.ratelimit(user.to_string()) + self.ratelimit(requester) joins = yield self.store.get_rooms_for_user( user.to_string(), diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d2de23a6cc..91fe306cf4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester from synapse.api.constants import ( EventTypes, Membership, JoinRules, RoomCreationPreset, ) @@ -90,7 +90,7 @@ class RoomCreationHandler(BaseHandler): """ 
user_id = requester.user.to_string() - self.ratelimit(user_id) + self.ratelimit(requester) if "room_alias_name" in config: for wchar in string.whitespace: @@ -185,23 +185,29 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Name, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"name": name}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Name, + "room_id": room_id, + "sender": user_id, + "state_key": "", + "content": {"name": name}, + }, + ratelimit=False) if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_nonmember_event({ - "type": EventTypes.Topic, - "room_id": room_id, - "sender": user_id, - "state_key": "", - "content": {"topic": topic}, - }, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Topic, + "room_id": room_id, + "sender": user_id, + "state_key": "", + "content": {"topic": topic}, + }, + ratelimit=False) for invitee in invite_list: room_member_handler.update_membership( @@ -231,7 +237,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() yield directory_handler.send_room_alias_update_event( - user_id, room_id + requester, user_id, room_id ) defer.returnValue(result) @@ -263,7 +269,11 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def send(etype, content, **kwargs): event = create(etype, content, **kwargs) - yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) + yield msg_handler.create_and_send_nonmember_event( + creator, + event, + ratelimit=False + ) config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -454,12 +464,11 @@ class RoomMemberHandler(BaseHandler): member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event( + requester, event, context, - is_guest=requester.is_guest, ratelimit=ratelimit, remote_room_hosts=remote_room_hosts, - from_client=True, ) if action == "forget": @@ -468,17 +477,19 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def send_membership_event( self, + requester, event, context, - is_guest=False, remote_room_hosts=None, ratelimit=True, - from_client=True, ): """ Change the membership status of a user in a room. Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. event (SynapseEvent): The membership event. context: The context of the event. is_guest (bool): Whether the sender is a guest. @@ -486,19 +497,21 @@ class RoomMemberHandler(BaseHandler): the room, and could be danced with in order to join this homeserver for the first time. ratelimit (bool): Whether to rate limit this request. - from_client (bool): Whether this request is the result of a local - client request (rather than over federation). If so, we will - perform extra checks, like that this homeserver can act as this - client. Raises: SynapseError if there was a problem changing the membership. 
""" target_user = UserID.from_string(event.state_key) room_id = event.room_id - if from_client: + if requester is not None: sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = Requester(target_user, None, False) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) @@ -508,7 +521,7 @@ class RoomMemberHandler(BaseHandler): action = "send" if event.membership == Membership.JOIN: - if is_guest and not self._can_guest_join(context.current_state): + if requester.is_guest and not self._can_guest_join(context.current_state): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") @@ -551,6 +564,7 @@ class RoomMemberHandler(BaseHandler): ) else: yield self.handle_new_client_event( + requester, event, context, extra_users=[target_user], @@ -669,12 +683,12 @@ class RoomMemberHandler(BaseHandler): ) else: yield self._make_and_store_3pid_invite( + requester, id_server, medium, address, room_id, inviter, - requester.access_token_id, txn_id=txn_id ) @@ -732,12 +746,12 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _make_and_store_3pid_invite( self, + requester, id_server, medium, address, room_id, user, - token_id, txn_id ): room_state = yield self.hs.get_state_handler().get_current_state(room_id) @@ -787,6 +801,7 @@ class RoomMemberHandler(BaseHandler): msg_handler = self.hs.get_handlers().message_handler yield msg_handler.create_and_send_nonmember_event( + requester, { "type": EventTypes.ThirdPartyInvite, "content": { @@ -801,7 +816,6 @@ class RoomMemberHandler(BaseHandler): "sender": user.to_string(), "state_key": token, }, - token_id=token_id, txn_id=txn_id, ) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 74ec1e50e0..8c1a2614a0 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -75,7 +75,11 @@ class ClientDirectoryServer(ClientV1RestServlet): yield dir_handler.create_association( user_id, room_alias, room_id, servers ) - yield dir_handler.send_room_alias_update_event(user_id, room_id) + yield dir_handler.send_room_alias_update_event( + requester, + user_id, + room_id + ) except SynapseError as e: raise e except: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 3c5a212920..953764bd8e 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -51,7 +51,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_displayname( - user, requester.user, new_name) + user, requester, new_name) defer.returnValue((200, {})) @@ -88,7 +88,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_avatar_url( - user, requester.user, new_name) + user, requester, new_name) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f5ed4f7302..cbf3673eff 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -158,12 +158,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): if event_type == 
EventTypes.Member: yield self.handlers.room_member_handler.send_membership_event( + requester, event, context, - is_guest=requester.is_guest, ) else: - yield msg_handler.send_nonmember_event(event, context) + yield msg_handler.send_nonmember_event(requester, event, context) defer.returnValue((200, {"event_id": event.event_id})) @@ -183,13 +183,13 @@ class RoomSendEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( + requester, { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), }, - token_id=requester.access_token_id, txn_id=txn_id, ) @@ -504,6 +504,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( + requester, { "type": EventTypes.Redaction, "content": content, @@ -511,7 +512,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "sender": requester.user.to_string(), "redacts": event_id, }, - token_id=requester.access_token_id, txn_id=txn_id, ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a87703bbfd..4f2c14e4ff 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -23,7 +23,7 @@ from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, requester_for_user class ProfileHandlers(object): @@ -84,7 +84,11 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.") + yield self.handler.set_displayname( + self.frank, + requester_for_user(self.frank), + "Frank Jr." + ) self.assertEquals( (yield self.store.get_profile_displayname(self.frank.localpart)), @@ -93,7 +97,11 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.") + d = self.handler.set_displayname( + self.frank, + requester_for_user(self.bob), + "Frank Jr." 
+ ) yield self.assertFailure(d, AuthError) @@ -136,7 +144,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): yield self.handler.set_avatar_url( - self.frank, self.frank, "http://my.server/pic.gif" + self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" ) self.assertEquals( diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 38daaf87e2..daabc563b4 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -18,7 +18,7 @@ from synapse.types import Requester, UserID from twisted.internet import defer from tests import unittest -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, requester_for_user from mock import Mock, NonCallableMock import json import contextlib @@ -133,12 +133,15 @@ class ReplicationResourceCase(unittest.TestCase): @defer.inlineCallbacks def send_text_message(self, room_id, message): handler = self.hs.get_handlers().message_handler - event = yield handler.create_and_send_nonmember_event({ - "type": "m.room.message", - "content": {"body": "message", "msgtype": "m.text"}, - "room_id": room_id, - "sender": self.user.to_string(), - }) + event = yield handler.create_and_send_nonmember_event( + requester_for_user(self.user), + { + "type": "m.room.message", + "content": {"body": "message", "msgtype": "m.text"}, + "room_id": room_id, + "sender": self.user.to_string(), + } + ) defer.returnValue(event.event_id) @defer.inlineCallbacks diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 0785965de2..1d210f9bf8 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -86,7 +86,7 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") + self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") @defer.inlineCallbacks @@ -155,5 +155,5 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(200, code) self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") + self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") diff --git a/tests/utils.py b/tests/utils.py index c67fa1ca35..291b549053 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,6 +20,7 @@ from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer from synapse.federation.transport import server +from synapse.types import Requester from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -510,3 +511,7 @@ class DeferredMockCallable(object): "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) + + +def requester_for_user(user): + return Requester(user, None, False) -- cgit 1.4.1 From bb0e82fff186b509e82184ba679b7b8eb6db26c5 Mon Sep 17 00:00:00 2001 From: Patrik Oldsberg Date: Tue, 23 Feb 2016 14:48:12 +0100 Subject: tests/utils: added room_invite_state_types to test config Signed-off-by: Patrik Oldsberg --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests/utils.py') diff --git a/tests/utils.py 
b/tests/utils.py index 291b549053..52405502e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,6 +51,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.macaroon_secret_key = "not even a little secret" config.server_name = "server.under.test" config.trusted_third_party_id_servers = [] + config.room_invite_state_types = [] config.database_config = {"name": "sqlite3"} -- cgit 1.4.1
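
The "Pass whole requester to ratelimiting" change above threads a full `Requester` (user, access token id, guest flag) through `ratelimit()` and `create_and_send_nonmember_event()` instead of a bare `user_id`/`token_id`. The sketch below is illustrative only: it mirrors the call shape shown in these diffs (including the `requester_for_user()` helper added to `tests/utils.py`), but the homeserver object `hs`, the room id, and the user `@alice:example.com` are hypothetical stand-ins, and other Synapse versions may expose different signatures.

```python
# Minimal sketch of the post-patch calling convention; not part of the series.
from twisted.internet import defer

from synapse.types import Requester, UserID


def requester_for_user(user):
    # Same shape as the helper added to tests/utils.py in this series:
    # a Requester with no access token and is_guest=False.
    return Requester(user, None, False)


@defer.inlineCallbacks
def send_hello(hs, room_id):
    # Handlers now receive the whole Requester, so ratelimiting (and any
    # later checks) can see who is acting and with which access token,
    # rather than being handed a bare user_id string.
    user = UserID.from_string("@alice:example.com")  # hypothetical user
    requester = requester_for_user(user)

    msg_handler = hs.get_handlers().message_handler
    event = yield msg_handler.create_and_send_nonmember_event(
        requester,
        {
            "type": "m.room.message",
            "content": {"body": "hello", "msgtype": "m.text"},
            "room_id": room_id,
            "sender": user.to_string(),
        },
    )
    defer.returnValue(event.event_id)
```

Inside the handler, `ratelimit(requester)` then keys the limiter off `requester.user.to_string()`, as the `synapse/handlers/_base.py` hunk above shows; tests that previously passed a second user object (e.g. the profile tests) now pass `requester_for_user(...)` instead.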