summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xscripts-dev/definitions.py23
-rw-r--r--synapse/handlers/auth.py14
-rw-r--r--synapse/handlers/events.py70
-rw-r--r--synapse/handlers/presence.py8
-rw-r--r--synapse/handlers/register.py12
-rw-r--r--synapse/rest/client/v2_alpha/register.py18
-rw-r--r--synapse/storage/roommember.py24
-rw-r--r--tests/storage/test_roommember.py10
8 files changed, 60 insertions, 119 deletions
diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py
index 8340c72618..47dac7772d 100755
--- a/scripts-dev/definitions.py
+++ b/scripts-dev/definitions.py
@@ -86,9 +86,12 @@ def used_names(prefix, item, defs, names):
     for name, funcs in defs.get('class', {}).items():
         used_names(prefix + name + ".", name, funcs, names)
 
+    path = prefix.rstrip('.')
     for used in defs.get('uses', ()):
         if used in names:
-            names[used].setdefault('used', {}).setdefault(item, []).append(prefix.rstrip('.'))
+            if item:
+                names[item].setdefault('uses', []).append(used)
+            names[used].setdefault('used', {}).setdefault(item, []).append(path)
 
 
 if __name__ == '__main__':
@@ -114,6 +117,10 @@ if __name__ == '__main__':
         help="Include referrers up to the given depth"
     )
     parser.add_argument(
+        "--referred", default=0, type=int,
+        help="Include referred down to the given depth"
+    )
+    parser.add_argument(
         "--format", default="yaml",
         help="Output format, one of 'yaml' or 'dot'"
     )
@@ -161,6 +168,20 @@ if __name__ == '__main__':
                 continue
             result[name] = definition
 
+    referred_depth = args.referred
+    referred = set()
+    while referred_depth:
+        referred_depth -= 1
+        for entry in result.values():
+            for uses in entry.get("uses", ()):
+                referred.add(uses)
+        for name, definition in names.items():
+            if not name in referred:
+                continue
+            if ignore and any(pattern.match(name) for pattern in ignore):
+                continue
+            result[name] = definition
+
     if args.format == 'yaml':
         yaml.dump(result, sys.stdout, default_flow_style=False)
     elif args.format == 'dot':
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d7233cd0d6..82d458b424 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -160,6 +160,20 @@ class AuthHandler(BaseHandler):
             defer.returnValue(True)
         defer.returnValue(False)
 
+    def get_session_id(self, clientdict):
+        """
+        Gets the session ID for a client given the client dictionary
+        :param clientdict: The dictionary sent by the client in the request
+        :return: The string session ID the client sent. If the client did not
+                 send a session ID, returns None.
+        """
+        sid = None
+        if clientdict and 'auth' in clientdict:
+            authdict = clientdict['auth']
+            if 'session' in authdict:
+                sid = authdict['session']
+        return sid
+
     def set_session_data(self, session_id, key, value):
         """
         Store a key-value pair into the sessions data associated with this
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 72a31a9755..f25a252523 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
 from synapse.util.logutils import log_function
 from synapse.types import UserID
 from synapse.events.utils import serialize_event
-from synapse.util.logcontext import preserve_context_over_fn
 from synapse.api.constants import Membership, EventTypes
 from synapse.events import EventBase
 
@@ -31,20 +30,6 @@ import random
 logger = logging.getLogger(__name__)
 
 
-def started_user_eventstream(distributor, user):
-    return preserve_context_over_fn(
-        distributor.fire,
-        "started_user_eventstream", user
-    )
-
-
-def stopped_user_eventstream(distributor, user):
-    return preserve_context_over_fn(
-        distributor.fire,
-        "stopped_user_eventstream", user
-    )
-
-
 class EventStreamHandler(BaseHandler):
 
     def __init__(self, hs):
@@ -64,61 +49,6 @@ class EventStreamHandler(BaseHandler):
         self.notifier = hs.get_notifier()
 
     @defer.inlineCallbacks
-    def started_stream(self, user):
-        """Tells the presence handler that we have started an eventstream for
-        the user:
-
-        Args:
-            user (User): The user who started a stream.
-        Returns:
-            A deferred that completes once their presence has been updated.
-        """
-        if user not in self._streams_per_user:
-            # Make sure we set the streams per user to 1 here rather than
-            # setting it to zero and incrementing the value below.
-            # Otherwise this may race with stopped_stream causing the
-            # user to be erased from the map before we have a chance
-            # to increment it.
-            self._streams_per_user[user] = 1
-            if user in self._stop_timer_per_user:
-                try:
-                    self.clock.cancel_call_later(
-                        self._stop_timer_per_user.pop(user)
-                    )
-                except:
-                    logger.exception("Failed to cancel event timer")
-            else:
-                yield started_user_eventstream(self.distributor, user)
-        else:
-            self._streams_per_user[user] += 1
-
-    def stopped_stream(self, user):
-        """If there are no streams for a user this starts a timer that will
-        notify the presence handler that we haven't got an event stream for
-        the user unless the user starts a new stream in 30 seconds.
-
-        Args:
-            user (User): The user who stopped a stream.
-        """
-        self._streams_per_user[user] -= 1
-        if not self._streams_per_user[user]:
-            del self._streams_per_user[user]
-
-            # 30 seconds of grace to allow the client to reconnect again
-            #   before we think they're gone
-            def _later():
-                logger.debug("_later stopped_user_eventstream %s", user)
-
-                self._stop_timer_per_user.pop(user, None)
-
-                return stopped_user_eventstream(self.distributor, user)
-
-            logger.debug("Scheduling _later: for %s", user)
-            self._stop_timer_per_user[user] = (
-                self.clock.call_later(30, _later)
-            )
-
-    @defer.inlineCallbacks
     @log_function
     def get_stream(self, auth_user_id, pagin_config, timeout=0,
                    as_client_event=True, affect_presence=True,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index f6cf343174..d0c8f1328b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -73,14 +73,6 @@ FEDERATION_PING_INTERVAL = 25 * 60 * 1000
 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 
 
-def user_presence_changed(distributor, user, statuscache):
-    return distributor.fire("user_presence_changed", user, statuscache)
-
-
-def collect_presencelike_data(distributor, user, content):
-    return distributor.fire("collect_presencelike_data", user, content)
-
-
 class PresenceHandler(BaseHandler):
 
     def __init__(self, hs):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6ffb8c0da6..f287ee247b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
         self._next_generated_user_id = None
 
     @defer.inlineCallbacks
-    def check_username(self, localpart, guest_access_token=None):
+    def check_username(self, localpart, guest_access_token=None,
+                       assigned_user_id=None):
         yield run_on_reactor()
 
         if urllib.quote(localpart.encode('utf-8')) != localpart:
@@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler):
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
 
+        if assigned_user_id:
+            if user_id == assigned_user_id:
+                return
+            else:
+                raise SynapseError(
+                    400,
+                    "A different user ID has already been registered for this session",
+                )
+
         yield self.check_user_id_not_appservice_exclusive(user_id)
 
         users = yield self.store.get_users_by_id_case_insensitive(user_id)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c440430e25..d32c06c882 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet):
 
         guest_access_token = body.get("guest_access_token", None)
 
+        session_id = self.auth_handler.get_session_id(body)
+        registered_user_id = None
+        if session_id:
+            # if we get a registered user id out of here, it means we previously
+            # registered a user for this session, so we could just return the
+            # user here. We carry on and go through the auth checks though,
+            # for paranoia.
+            registered_user_id = self.auth_handler.get_session_data(
+                session_id, "registered_user_id", None
+            )
+
         if desired_username is not None:
             yield self.registration_handler.check_username(
                 desired_username,
-                guest_access_token=guest_access_token
+                guest_access_token=guest_access_token,
+                assigned_user_id=registered_user_id,
             )
 
         if self.hs.config.enable_registration_captcha:
@@ -147,10 +159,6 @@ class RegisterRestServlet(RestServlet):
             defer.returnValue((401, result))
             return
 
-        # have we already registered a user for this session
-        registered_user_id = self.auth_handler.get_session_data(
-            session_id, "registered_user_id", None
-        )
         if registered_user_id is not None:
             logger.info(
                 "Already registered user ID %r for this session",
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 3065b0c1a5..0cd89260f2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -252,30 +252,6 @@ class RoomMemberStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def user_rooms_intersect(self, user_id_list):
-        """ Checks whether all the users whose IDs are given in a list share a
-        room.
-
-        This is a "hot path" function that's called a lot, e.g. by presence for
-        generating the event stream. As such, it is implemented locally by
-        wrapping logic around heavily-cached database queries.
-        """
-        if len(user_id_list) < 2:
-            defer.returnValue(True)
-
-        deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
-
-        results = yield defer.DeferredList(deferreds, consumeErrors=True)
-
-        # A list of sets of strings giving room IDs for each user
-        room_id_lists = [set([r.room_id for r in result[1]]) for result in results]
-
-        # There isn't a setintersection(*list_of_sets)
-        ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
-
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 677d11f68d..b029ff0584 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -91,11 +91,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
                 )
             )]
         )
-        self.assertFalse(
-            (yield self.store.user_rooms_intersect(
-                [self.u_alice.to_string(), self.u_bob.to_string()]
-            ))
-        )
 
     @defer.inlineCallbacks
     def test_two_members(self):
@@ -108,11 +103,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
                 yield self.store.get_room_members(self.room.to_string())
             )}
         )
-        self.assertTrue((
-            yield self.store.user_rooms_intersect([
-                self.u_alice.to_string(), self.u_bob.to_string()
-            ])
-        ))
 
     @defer.inlineCallbacks
     def test_room_hosts(self):