summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/sync.py103
-rw-r--r--synapse/http/server.py114
-rw-r--r--synapse/metrics/__init__.py16
-rw-r--r--synapse/storage/events.py2
-rw-r--r--synapse/storage/roommember.py27
5 files changed, 198 insertions, 64 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 56b86356f2..0f713ce038 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -235,10 +235,10 @@ class SyncHandler(object):
         defer.returnValue(rules)
 
     @defer.inlineCallbacks
-    def ephemeral_by_room(self, sync_config, now_token, since_token=None):
+    def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
         """Get the ephemeral events for each room the user is in
         Args:
-            sync_config (SyncConfig): The flags, filters and user for the sync.
+            sync_result_builder(SyncResultBuilder)
             now_token (StreamToken): Where the server is currently up to.
             since_token (StreamToken): Where the server was when the client
                 last synced.
@@ -248,10 +248,12 @@ class SyncHandler(object):
             typing events for that room.
         """
 
+        sync_config = sync_result_builder.sync_config
+
         with Measure(self.clock, "ephemeral_by_room"):
             typing_key = since_token.typing_key if since_token else "0"
 
-            room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+            room_ids = sync_result_builder.joined_room_ids
 
             typing_source = self.event_sources.sources["typing"]
             typing, typing_key = yield typing_source.get_new_events(
@@ -565,10 +567,22 @@ class SyncHandler(object):
         # Always use the `now_token` in `SyncResultBuilder`
         now_token = yield self.event_sources.get_current_token()
 
+        user_id = sync_config.user.to_string()
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service:
+            # We no longer support AS users using /sync directly.
+            # See https://github.com/matrix-org/matrix-doc/issues/1144
+            raise NotImplementedError()
+        else:
+            joined_room_ids = yield self.get_rooms_for_user_at(
+                user_id, now_token.room_stream_id,
+            )
+
         sync_result_builder = SyncResultBuilder(
             sync_config, full_state,
             since_token=since_token,
             now_token=now_token,
+            joined_room_ids=joined_room_ids,
         )
 
         account_data_by_room = yield self._generate_sync_entry_for_account_data(
@@ -603,7 +617,6 @@ class SyncHandler(object):
         device_id = sync_config.device_id
         one_time_key_counts = {}
         if device_id:
-            user_id = sync_config.user.to_string()
             one_time_key_counts = yield self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
@@ -891,7 +904,7 @@ class SyncHandler(object):
             ephemeral_by_room = {}
         else:
             now_token, ephemeral_by_room = yield self.ephemeral_by_room(
-                sync_result_builder.sync_config,
+                sync_result_builder,
                 now_token=sync_result_builder.now_token,
                 since_token=sync_result_builder.since_token,
             )
@@ -996,16 +1009,8 @@ class SyncHandler(object):
         if rooms_changed:
             defer.returnValue(True)
 
-        app_service = self.store.get_app_service_by_user_id(user_id)
-        if app_service:
-            # We no longer support AS users using /sync directly.
-            # See https://github.com/matrix-org/matrix-doc/issues/1144
-            raise NotImplementedError()
-        else:
-            joined_room_ids = yield self.store.get_rooms_for_user(user_id)
-
         stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
-        for room_id in joined_room_ids:
+        for room_id in sync_result_builder.joined_room_ids:
             if self.store.has_room_changed_since(room_id, stream_id):
                 defer.returnValue(True)
         defer.returnValue(False)
@@ -1029,14 +1034,6 @@ class SyncHandler(object):
 
         assert since_token
 
-        app_service = self.store.get_app_service_by_user_id(user_id)
-        if app_service:
-            # We no longer support AS users using /sync directly.
-            # See https://github.com/matrix-org/matrix-doc/issues/1144
-            raise NotImplementedError()
-        else:
-            joined_room_ids = yield self.store.get_rooms_for_user(user_id)
-
         # Get a list of membership change events that have happened.
         rooms_changed = yield self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
@@ -1059,7 +1056,7 @@ class SyncHandler(object):
             # we do send down the room, and with full state, where necessary
 
             old_state_ids = None
-            if room_id in joined_room_ids and non_joins:
+            if room_id in sync_result_builder.joined_room_ids and non_joins:
                 # Always include if the user (re)joined the room, especially
                 # important so that device list changes are calculated correctly.
                 # If there are non join member events, but we are still in the room,
@@ -1069,7 +1066,7 @@ class SyncHandler(object):
                 # User is in the room so we don't need to do the invite/leave checks
                 continue
 
-            if room_id in joined_room_ids or has_join:
+            if room_id in sync_result_builder.joined_room_ids or has_join:
                 old_state_ids = yield self.get_state_at(room_id, since_token)
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
@@ -1081,7 +1078,7 @@ class SyncHandler(object):
                     newly_joined_rooms.append(room_id)
 
             # If user is in the room then we don't need to do the invite/leave checks
-            if room_id in joined_room_ids:
+            if room_id in sync_result_builder.joined_room_ids:
                 continue
 
             if not non_joins:
@@ -1148,7 +1145,7 @@ class SyncHandler(object):
 
         # Get all events for rooms we're currently joined to.
         room_to_events = yield self.store.get_room_events_stream_for_rooms(
-            room_ids=joined_room_ids,
+            room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
             to_key=now_token.room_key,
             limit=timeline_limit + 1,
@@ -1156,7 +1153,7 @@ class SyncHandler(object):
 
         # We loop through all room ids, even if there are no new events, in case
         # there are non room events taht we need to notify about.
-        for room_id in joined_room_ids:
+        for room_id in sync_result_builder.joined_room_ids:
             room_entry = room_to_events.get(room_id, None)
 
             if room_entry:
@@ -1364,6 +1361,54 @@ class SyncHandler(object):
         else:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
+    @defer.inlineCallbacks
+    def get_rooms_for_user_at(self, user_id, stream_ordering):
+        """Get set of joined rooms for a user at the given stream ordering.
+
+        The stream ordering *must* be recent, otherwise this may throw an
+        exception if older than a month. (This function is called with the
+        current token, which should be perfectly fine).
+
+        Args:
+            user_id (str)
+            stream_ordering (int)
+
+        ReturnValue:
+            Deferred[frozenset[str]]: Set of room_ids the user is in at given
+            stream_ordering.
+        """
+        joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
+            user_id,
+        )
+
+        joined_room_ids = set()
+
+        # We need to check that the stream ordering of the join for each room
+        # is before the stream_ordering asked for. This might not be the case
+        # if the user joins a room between us getting the current token and
+        # calling `get_rooms_for_user_with_stream_ordering`.
+        # If the membership's stream ordering is after the given stream
+        # ordering, we need to go and work out if the user was in the room
+        # before.
+        for room_id, membership_stream_ordering in joined_rooms:
+            if membership_stream_ordering <= stream_ordering:
+                joined_room_ids.add(room_id)
+                continue
+
+            logger.info("User joined room after current token: %s", room_id)
+
+            extrems = yield self.store.get_forward_extremeties_for_room(
+                room_id, stream_ordering,
+            )
+            users_in_room = yield self.state.get_current_user_in_room(
+                room_id, extrems,
+            )
+            if user_id in users_in_room:
+                joined_room_ids.add(room_id)
+
+        joined_room_ids = frozenset(joined_room_ids)
+        defer.returnValue(joined_room_ids)
+
 
 def _action_has_highlight(actions):
     for action in actions:
@@ -1413,7 +1458,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
 
 class SyncResultBuilder(object):
     "Used to help build up a new SyncResult for a user"
-    def __init__(self, sync_config, full_state, since_token, now_token):
+    def __init__(self, sync_config, full_state, since_token, now_token,
+                 joined_room_ids):
         """
         Args:
             sync_config(SyncConfig)
@@ -1425,6 +1471,7 @@ class SyncResultBuilder(object):
         self.full_state = full_state
         self.since_token = since_token
         self.now_token = now_token
+        self.joined_room_ids = joined_room_ids
 
         self.presence = []
         self.account_data = []
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 165c684d0d..4b567215c8 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -59,6 +60,11 @@ response_count = metrics.register_counter(
     )
 )
 
+requests_counter = metrics.register_counter(
+    "requests_received",
+    labels=["method", "servlet", ],
+)
+
 outgoing_responses_counter = metrics.register_counter(
     "responses",
     labels=["method", "code"],
@@ -145,7 +151,8 @@ def wrap_request_handler(request_handler, include_metrics=False):
                 # at the servlet name. For most requests that name will be
                 # JsonResource (or a subclass), and JsonResource._async_render
                 # will update it once it picks a servlet.
-                request_metrics.start(self.clock, name=self.__class__.__name__)
+                servlet_name = self.__class__.__name__
+                request_metrics.start(self.clock, name=servlet_name)
 
                 request_context.request = request_id
                 with request.processing():
@@ -154,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
                             if include_metrics:
                                 yield request_handler(self, request, request_metrics)
                             else:
+                                requests_counter.inc(request.method, servlet_name)
                                 yield request_handler(self, request)
                     except CodeMessageException as e:
                         code = e.code
@@ -229,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource):
     """ This implements the HttpServer interface and provides JSON support for
     Resources.
 
-    Register callbacks via register_path()
+    Register callbacks via register_paths()
 
     Callbacks can return a tuple of status code and a dict in which case the
     the dict will automatically be sent to the client as a JSON object.
@@ -276,49 +284,59 @@ class JsonResource(HttpServer, resource.Resource):
             This checks if anyone has registered a callback for that method and
             path.
         """
-        if request.method == "OPTIONS":
-            self._send_response(request, 200, {})
-            return
+        callback, group_dict = self._get_handler_for_request(request)
 
-        # Loop through all the registered callbacks to check if the method
-        # and path regex match
-        for path_entry in self.path_regexs.get(request.method, []):
-            m = path_entry.pattern.match(request.path)
-            if not m:
-                continue
+        servlet_instance = getattr(callback, "__self__", None)
+        if servlet_instance is not None:
+            servlet_classname = servlet_instance.__class__.__name__
+        else:
+            servlet_classname = "%r" % callback
 
-            # We found a match! First update the metrics object to indicate
-            # which servlet is handling the request.
+        request_metrics.name = servlet_classname
+        requests_counter.inc(request.method, servlet_classname)
 
-            callback = path_entry.callback
+        # Now trigger the callback. If it returns a response, we send it
+        # here. If it throws an exception, that is handled by the wrapper
+        # installed by @request_handler.
 
-            servlet_instance = getattr(callback, "__self__", None)
-            if servlet_instance is not None:
-                servlet_classname = servlet_instance.__class__.__name__
-            else:
-                servlet_classname = "%r" % callback
+        kwargs = intern_dict({
+            name: urllib.unquote(value).decode("UTF-8") if value else value
+            for name, value in group_dict.items()
+        })
 
-            request_metrics.name = servlet_classname
+        callback_return = yield callback(request, **kwargs)
+        if callback_return is not None:
+            code, response = callback_return
+            self._send_response(request, code, response)
 
-            # Now trigger the callback. If it returns a response, we send it
-            # here. If it throws an exception, that is handled by the wrapper
-            # installed by @request_handler.
+    def _get_handler_for_request(self, request):
+        """Finds a callback method to handle the given request
 
-            kwargs = intern_dict({
-                name: urllib.unquote(value).decode("UTF-8") if value else value
-                for name, value in m.groupdict().items()
-            })
+        Args:
+            request (twisted.web.http.Request):
+
+        Returns:
+            Tuple[Callable, dict[str, str]]: callback method, and the dict
+                mapping keys to path components as specified in the handler's
+                path match regexp.
 
-            callback_return = yield callback(request, **kwargs)
-            if callback_return is not None:
-                code, response = callback_return
-                self._send_response(request, code, response)
+                The callback will normally be a method registered via
+                register_paths, so will return (possibly via Deferred) either
+                None, or a tuple of (http code, response body).
+        """
+        if request.method == "OPTIONS":
+            return _options_handler, {}
 
-            return
+        # Loop through all the registered callbacks to check if the method
+        # and path regex match
+        for path_entry in self.path_regexs.get(request.method, []):
+            m = path_entry.pattern.match(request.path)
+            if m:
+                # We found a match!
+                return path_entry.callback, m.groupdict()
 
         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
-        raise UnrecognizedRequestError()
+        return _unrecognised_request_handler, {}
 
     def _send_response(self, request, code, response_json_object,
                        response_code_message=None):
@@ -335,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource):
         )
 
 
+def _options_handler(request):
+    """Request handler for OPTIONS requests
+
+    This is a request handler suitable for return from
+    _get_handler_for_request. It returns a 200 and an empty body.
+
+    Args:
+        request (twisted.web.http.Request):
+
+    Returns:
+        Tuple[int, dict]: http code, response body.
+    """
+    return 200, {}
+
+
+def _unrecognised_request_handler(request):
+    """Request handler for unrecognised requests
+
+    This is a request handler suitable for return from
+    _get_handler_for_request. It actually just raises an
+    UnrecognizedRequestError.
+
+    Args:
+        request (twisted.web.http.Request):
+    """
+    raise UnrecognizedRequestError()
+
+
 class RequestMetrics(object):
     def start(self, clock, name):
         self.start = clock.time_msec()
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index e0cfb7d08f..50d99d7a5c 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -57,15 +57,31 @@ class Metrics(object):
         return metric
 
     def register_counter(self, *args, **kwargs):
+        """
+        Returns:
+            CounterMetric
+        """
         return self._register(CounterMetric, *args, **kwargs)
 
     def register_callback(self, *args, **kwargs):
+        """
+        Returns:
+            CallbackMetric
+        """
         return self._register(CallbackMetric, *args, **kwargs)
 
     def register_distribution(self, *args, **kwargs):
+        """
+        Returns:
+            DistributionMetric
+        """
         return self._register(DistributionMetric, *args, **kwargs)
 
     def register_cache(self, *args, **kwargs):
+        """
+        Returns:
+            CacheMetric
+        """
         return self._register(CacheMetric, *args, **kwargs)
 
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 057b1be4d5..826fad307e 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -754,7 +754,7 @@ class EventsStore(EventsWorkerStore):
 
                 for member in members_changed:
                     self._invalidate_cache_and_stream(
-                        txn, self.get_rooms_for_user, (member,)
+                        txn, self.get_rooms_for_user_with_stream_ordering, (member,)
                     )
 
                 for host in set(get_domain_from_id(u) for u in members_changed):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index d79877dac7..52e19e16b0 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -38,6 +38,11 @@ RoomsForUser = namedtuple(
     ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
+GetRoomsForUserWithStreamOrdering = namedtuple(
+    "_GetRoomsForUserWithStreamOrdering",
+    ("room_id", "stream_ordering",)
+)
+
 
 # We store this using a namedtuple so that we save about 3x space over using a
 # dict.
@@ -181,12 +186,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_rooms_for_user(self, user_id):
+    def get_rooms_for_user_with_stream_ordering(self, user_id):
         """Returns a set of room_ids the user is currently joined to
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+            the rooms the user is in currently, along with the stream ordering
+            of the most recent join for that user and room.
         """
         rooms = yield self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
+        defer.returnValue(frozenset(
+            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+            for r in rooms
+        ))
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id, on_invalidate=None):
+        """Returns a set of room_ids the user is currently joined to
+        """
+        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+            user_id, on_invalidate=on_invalidate,
+        )
         defer.returnValue(frozenset(r.room_id for r in rooms))
 
     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)