diff options
-rwxr-xr-x | scripts-dev/definitions.py | 23 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 14 | ||||
-rw-r--r-- | synapse/handlers/events.py | 70 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 8 | ||||
-rw-r--r-- | synapse/handlers/register.py | 12 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 18 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 24 | ||||
-rw-r--r-- | tests/storage/test_roommember.py | 10 |
8 files changed, 60 insertions, 119 deletions
diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 8340c72618..47dac7772d 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -86,9 +86,12 @@ def used_names(prefix, item, defs, names): for name, funcs in defs.get('class', {}).items(): used_names(prefix + name + ".", name, funcs, names) + path = prefix.rstrip('.') for used in defs.get('uses', ()): if used in names: - names[used].setdefault('used', {}).setdefault(item, []).append(prefix.rstrip('.')) + if item: + names[item].setdefault('uses', []).append(used) + names[used].setdefault('used', {}).setdefault(item, []).append(path) if __name__ == '__main__': @@ -114,6 +117,10 @@ if __name__ == '__main__': help="Include referrers up to the given depth" ) parser.add_argument( + "--referred", default=0, type=int, + help="Include referred down to the given depth" + ) + parser.add_argument( "--format", default="yaml", help="Output format, one of 'yaml' or 'dot'" ) @@ -161,6 +168,20 @@ if __name__ == '__main__': continue result[name] = definition + referred_depth = args.referred + referred = set() + while referred_depth: + referred_depth -= 1 + for entry in result.values(): + for uses in entry.get("uses", ()): + referred.add(uses) + for name, definition in names.items(): + if not name in referred: + continue + if ignore and any(pattern.match(name) for pattern in ignore): + continue + result[name] = definition + if args.format == 'yaml': yaml.dump(result, sys.stdout, default_flow_style=False) elif args.format == 'dot': diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d7233cd0d6..82d458b424 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -160,6 +160,20 @@ class AuthHandler(BaseHandler): defer.returnValue(True) defer.returnValue(False) + def get_session_id(self, clientdict): + """ + Gets the session ID for a client given the client dictionary + :param clientdict: The dictionary sent by the client in the request + :return: The string session ID the client sent. If the client did not + send a session ID, returns None. + """ + sid = None + if clientdict and 'auth' in clientdict: + authdict = clientdict['auth'] + if 'session' in authdict: + sid = authdict['session'] + return sid + def set_session_data(self, session_id, key, value): """ Store a key-value pair into the sessions data associated with this diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 72a31a9755..f25a252523 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event -from synapse.util.logcontext import preserve_context_over_fn from synapse.api.constants import Membership, EventTypes from synapse.events import EventBase @@ -31,20 +30,6 @@ import random logger = logging.getLogger(__name__) -def started_user_eventstream(distributor, user): - return preserve_context_over_fn( - distributor.fire, - "started_user_eventstream", user - ) - - -def stopped_user_eventstream(distributor, user): - return preserve_context_over_fn( - distributor.fire, - "stopped_user_eventstream", user - ) - - class EventStreamHandler(BaseHandler): def __init__(self, hs): @@ -64,61 +49,6 @@ class EventStreamHandler(BaseHandler): self.notifier = hs.get_notifier() @defer.inlineCallbacks - def started_stream(self, user): - """Tells the presence handler that we have started an eventstream for - the user: - - Args: - user (User): The user who started a stream. - Returns: - A deferred that completes once their presence has been updated. - """ - if user not in self._streams_per_user: - # Make sure we set the streams per user to 1 here rather than - # setting it to zero and incrementing the value below. - # Otherwise this may race with stopped_stream causing the - # user to be erased from the map before we have a chance - # to increment it. - self._streams_per_user[user] = 1 - if user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(user) - ) - except: - logger.exception("Failed to cancel event timer") - else: - yield started_user_eventstream(self.distributor, user) - else: - self._streams_per_user[user] += 1 - - def stopped_stream(self, user): - """If there are no streams for a user this starts a timer that will - notify the presence handler that we haven't got an event stream for - the user unless the user starts a new stream in 30 seconds. - - Args: - user (User): The user who stopped a stream. - """ - self._streams_per_user[user] -= 1 - if not self._streams_per_user[user]: - del self._streams_per_user[user] - - # 30 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug("_later stopped_user_eventstream %s", user) - - self._stop_timer_per_user.pop(user, None) - - return stopped_user_eventstream(self.distributor, user) - - logger.debug("Scheduling _later: for %s", user) - self._stop_timer_per_user[user] = ( - self.clock.call_later(30, _later) - ) - - @defer.inlineCallbacks @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, as_client_event=True, affect_presence=True, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index f6cf343174..d0c8f1328b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -73,14 +73,6 @@ FEDERATION_PING_INTERVAL = 25 * 60 * 1000 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -def user_presence_changed(distributor, user, statuscache): - return distributor.fire("user_presence_changed", user, statuscache) - - -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class PresenceHandler(BaseHandler): def __init__(self, hs): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6ffb8c0da6..f287ee247b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id = None @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None): + def check_username(self, localpart, guest_access_token=None, + assigned_user_id=None): yield run_on_reactor() if urllib.quote(localpart.encode('utf-8')) != localpart: @@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() + if assigned_user_id: + if user_id == assigned_user_id: + return + else: + raise SynapseError( + 400, + "A different user ID has already been registered for this session", + ) + yield self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c440430e25..d32c06c882 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) + session_id = self.auth_handler.get_session_id(body) + registered_user_id = None + if session_id: + # if we get a registered user id out of here, it means we previously + # registered a user for this session, so we could just return the + # user here. We carry on and go through the auth checks though, + # for paranoia. + registered_user_id = self.auth_handler.get_session_data( + session_id, "registered_user_id", None + ) + if desired_username is not None: yield self.registration_handler.check_username( desired_username, - guest_access_token=guest_access_token + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, ) if self.hs.config.enable_registration_captcha: @@ -147,10 +159,6 @@ class RegisterRestServlet(RestServlet): defer.returnValue((401, result)) return - # have we already registered a user for this session - registered_user_id = self.auth_handler.get_session_data( - session_id, "registered_user_id", None - ) if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3065b0c1a5..0cd89260f2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -252,30 +252,6 @@ class RoomMemberStore(SQLBaseStore): ) @defer.inlineCallbacks - def user_rooms_intersect(self, user_id_list): - """ Checks whether all the users whose IDs are given in a list share a - room. - - This is a "hot path" function that's called a lot, e.g. by presence for - generating the event stream. As such, it is implemented locally by - wrapping logic around heavily-cached database queries. - """ - if len(user_id_list) < 2: - defer.returnValue(True) - - deferreds = [self.get_rooms_for_user(u) for u in user_id_list] - - results = yield defer.DeferredList(deferreds, consumeErrors=True) - - # A list of sets of strings giving room IDs for each user - room_id_lists = [set([r.room_id for r in result[1]]) for result in results] - - # There isn't a setintersection(*list_of_sets) - ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 - - defer.returnValue(ret) - - @defer.inlineCallbacks def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 677d11f68d..b029ff0584 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -91,11 +91,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): ) )] ) - self.assertFalse( - (yield self.store.user_rooms_intersect( - [self.u_alice.to_string(), self.u_bob.to_string()] - )) - ) @defer.inlineCallbacks def test_two_members(self): @@ -108,11 +103,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.store.get_room_members(self.room.to_string()) )} ) - self.assertTrue(( - yield self.store.user_rooms_intersect([ - self.u_alice.to_string(), self.u_bob.to_string() - ]) - )) @defer.inlineCallbacks def test_room_hosts(self): |