diff options
-rw-r--r-- | synapse/handlers/sync.py | 269 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sync.py | 119 | ||||
-rw-r--r-- | synapse/storage/stream.py | 20 | ||||
-rw-r--r-- | synapse/streams/config.py | 14 | ||||
-rw-r--r-- | synapse/types.py | 74 |
5 files changed, 450 insertions, 46 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index be26a491ff..451182cfec 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.util.metrics import Measure from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user from synapse.visibility import filter_events_for_client +from synapse.types import SyncNextBatchToken, SyncPaginationState from twisted.internet import defer @@ -35,6 +36,22 @@ SyncConfig = collections.namedtuple("SyncConfig", [ "filter_collection", "is_guest", "request_key", + "pagination_config", +]) + + +SyncPaginationConfig = collections.namedtuple("SyncPaginationConfig", [ + "order", + "limit", +]) + +SYNC_PAGINATION_ORDER_TS = "o" +SYNC_PAGINATION_VALID_ORDERS = (SYNC_PAGINATION_ORDER_TS,) + + +SyncExtras = collections.namedtuple("SyncExtras", [ + "paginate", + "rooms", ]) @@ -113,6 +130,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. + "pagination_info", ])): __slots__ = [] @@ -140,8 +158,8 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache() - def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, - full_state=False): + def wait_for_sync_for_user(self, sync_config, batch_token=None, timeout=0, + full_state=False, extras=None): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. @@ -153,48 +171,42 @@ class SyncHandler(object): result = self.response_cache.set( sync_config.request_key, self._wait_for_sync_for_user( - sync_config, since_token, timeout, full_state + sync_config, batch_token, timeout, full_state, extras, ) ) return result @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, - full_state): + def _wait_for_sync_for_user(self, sync_config, batch_token, timeout, + full_state, extras=None): context = LoggingContext.current_context() if context: - if since_token is None: + if batch_token is None: context.tag = "initial_sync" elif full_state: context.tag = "full_state_sync" else: context.tag = "incremental_sync" - if timeout == 0 or since_token is None or full_state: + if timeout == 0 or batch_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user( - sync_config, since_token, full_state=full_state, + result = yield self.generate_sync_result( + sync_config, batch_token, full_state=full_state, extras=extras, ) defer.returnValue(result) else: def current_sync_callback(before_token, after_token): - return self.current_sync_for_user(sync_config, since_token) + return self.generate_sync_result( + sync_config, batch_token, full_state=False, extras=extras, + ) result = yield self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, - from_token=since_token, + from_token=batch_token.stream_token, ) defer.returnValue(result) - def current_sync_for_user(self, sync_config, since_token=None, - full_state=False): - """Get the sync for client needed to match what the server has now. - Returns: - A Deferred SyncResult. - """ - return self.generate_sync_result(sync_config, since_token, full_state) - @defer.inlineCallbacks def push_rules_for_user(self, user): user_id = user.to_string() @@ -490,7 +502,8 @@ class SyncHandler(object): defer.returnValue(None) @defer.inlineCallbacks - def generate_sync_result(self, sync_config, since_token=None, full_state=False): + def generate_sync_result(self, sync_config, batch_token=None, full_state=False, + extras=None): """Generates a sync result. Args: @@ -510,7 +523,7 @@ class SyncHandler(object): sync_result_builder = SyncResultBuilder( sync_config, full_state, - since_token=since_token, + batch_token=batch_token, now_token=now_token, ) @@ -519,7 +532,7 @@ class SyncHandler(object): ) res = yield self._generate_sync_entry_for_rooms( - sync_result_builder, account_data_by_room + sync_result_builder, account_data_by_room, extras, ) newly_joined_rooms, newly_joined_users = res @@ -533,7 +546,11 @@ class SyncHandler(object): joined=sync_result_builder.joined, invited=sync_result_builder.invited, archived=sync_result_builder.archived, - next_batch=sync_result_builder.now_token, + next_batch=SyncNextBatchToken( + stream_token=sync_result_builder.now_token, + pagination_state=sync_result_builder.pagination_state, + ), + pagination_info=sync_result_builder.pagination_info, )) @defer.inlineCallbacks @@ -646,7 +663,8 @@ class SyncHandler(object): sync_result_builder.presence = presence @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room, + extras): """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -659,6 +677,7 @@ class SyncHandler(object): `(newly_joined_rooms, newly_joined_users)` """ user_id = sync_result_builder.sync_config.user.to_string() + sync_config = sync_result_builder.sync_config now_token, ephemeral_by_room = yield self.ephemeral_by_room( sync_result_builder.sync_config, @@ -690,6 +709,94 @@ class SyncHandler(object): tags_by_room = yield self.store.get_tags_for_user(user_id) + if sync_config.pagination_config: + pagination_config = sync_config.pagination_config + old_pagination_value = 0 + elif sync_result_builder.pagination_state: + pagination_config = SyncPaginationConfig( + order=sync_result_builder.pagination_state.order, + limit=sync_result_builder.pagination_state.limit, + ) + old_pagination_value = sync_result_builder.pagination_state.value + else: + pagination_config = None + old_pagination_value = 0 + + include_map = extras.get("peek", {}) if extras else {} + + if sync_result_builder.pagination_state: + missing_state = yield self._get_rooms_that_need_full_state( + room_ids=[r.room_id for r in room_entries], + sync_config=sync_config, + since_token=sync_result_builder.since_token, + pagination_state=sync_result_builder.pagination_state, + ) + + if missing_state: + for r in room_entries: + if r.room_id in missing_state: + r.full_state = True + if r.room_id in include_map: + r.always_include = True + r.events = None + r.since_token = None + r.upto_token = now_token + + if pagination_config: + room_ids = [r.room_id for r in room_entries] + pagination_limit = pagination_config.limit + + extra_limit = extras.get("paginate", {}).get("limit", 0) if extras else 0 + + room_map = yield self._get_room_timestamps_at_token( + room_ids, sync_result_builder.now_token, sync_config, + pagination_limit + extra_limit + 1, + ) + + limited = False + if room_map: + sorted_list = sorted( + room_map.items(), + key=lambda item: -item[1] + ) + + cutoff_list = sorted_list[:pagination_limit + extra_limit] + + if cutoff_list[pagination_limit:]: + new_room_ids = set(r[0] for r in cutoff_list[pagination_limit:]) + for r in room_entries: + if r.room_id in new_room_ids: + r.full_state = True + r.always_include = True + r.since_token = None + r.upto_token = now_token + r.events = None + + _, bottom_ts = cutoff_list[-1] + value = bottom_ts + + limited = any( + old_pagination_value < r[1] < value + for r in sorted_list[pagination_limit + extra_limit:] + ) + + sync_result_builder.pagination_state = SyncPaginationState( + order=pagination_config.order, value=value, + limit=pagination_limit + extra_limit, + ) + + sync_result_builder.pagination_info["limited"] = limited + + if len(room_map) == len(room_entries): + sync_result_builder.pagination_state = None + + room_entries = [ + r for r in room_entries + if r.room_id in room_map or r.always_include + ] + + sync_result_builder.full_state |= sync_result_builder.since_token is None + def handle_room_entries(room_entry): return self._generate_room_entry( sync_result_builder, @@ -698,7 +805,6 @@ class SyncHandler(object): ephemeral=ephemeral_by_room.get(room_entry.room_id, []), tags=tags_by_room.get(room_entry.room_id), account_data=account_data_by_room.get(room_entry.room_id, {}), - always_include=sync_result_builder.full_state, ) yield concurrently_execute(handle_room_entries, room_entries, 10) @@ -929,8 +1035,7 @@ class SyncHandler(object): @defer.inlineCallbacks def _generate_room_entry(self, sync_result_builder, ignored_users, - room_builder, ephemeral, tags, account_data, - always_include=False): + room_builder, ephemeral, tags, account_data): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. @@ -946,6 +1051,11 @@ class SyncHandler(object): even if empty. """ newly_joined = room_builder.newly_joined + always_include = ( + newly_joined + or sync_result_builder.full_state + or room_builder.always_include + ) full_state = ( room_builder.full_state or newly_joined @@ -954,11 +1064,10 @@ class SyncHandler(object): events = room_builder.events # We want to shortcut out as early as possible. - if not (always_include or account_data or ephemeral or full_state): + if not (always_include or account_data or ephemeral): if events == [] and tags is None: return - since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -993,7 +1102,7 @@ class SyncHandler(object): ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) - if not (always_include or batch or account_data or ephemeral or full_state): + if not (always_include or batch or account_data or ephemeral): return state = yield self.compute_state_delta( @@ -1034,6 +1143,82 @@ class SyncHandler(object): else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) + @defer.inlineCallbacks + def _get_room_timestamps_at_token(self, room_ids, token, sync_config, limit): + room_to_entries = {} + + @defer.inlineCallbacks + def _get_last_ts(room_id): + entry = yield self.store.get_last_event_id_ts_for_room( + room_id, token.room_key + ) + + # TODO: Is this ever possible? + room_to_entries[room_id] = entry if entry else { + "origin_server_ts": 0, + } + + yield concurrently_execute(_get_last_ts, room_ids, 10) + + if len(room_to_entries) <= limit: + defer.returnValue({ + room_id: entry["origin_server_ts"] + for room_id, entry in room_to_entries.items() + }) + + queued_events = sorted( + room_to_entries.items(), + key=lambda e: -e[1]["origin_server_ts"] + ) + + to_return = {} + + while len(to_return) < limit and len(queued_events) > 0: + to_fetch = queued_events[:limit - len(to_return)] + event_to_q = { + e["event_id"]: (room_id, e) for room_id, e in to_fetch + if "event_id" in e + } + + # Now we fetch each event to check if its been filtered out + event_map = yield self.store.get_events(event_to_q.keys()) + + recents = sync_config.filter_collection.filter_room_timeline( + event_map.values() + ) + recents = yield filter_events_for_client( + self.store, + sync_config.user.to_string(), + recents, + ) + + to_return.update({r.room_id: r.origin_server_ts for r in recents}) + + for ev_id in set(event_map.keys()) - set(r.event_id for r in recents): + queued_events.append(event_to_q[ev_id]) + + # FIXME: Need to refetch TS + queued_events.sort(key=lambda e: -e[1]["origin_server_ts"]) + + defer.returnValue(to_return) + + @defer.inlineCallbacks + def _get_rooms_that_need_full_state(self, room_ids, sync_config, since_token, + pagination_state): + start_ts = yield self._get_room_timestamps_at_token( + room_ids, since_token, + sync_config=sync_config, + limit=len(room_ids), + ) + + missing_list = frozenset( + room_id for room_id, ts in + sorted(start_ts.items(), key=lambda item: -item[1]) + if ts < pagination_state.value + ) + + defer.returnValue(missing_list) + def _action_has_highlight(actions): for action in actions: @@ -1085,17 +1270,26 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): class SyncResultBuilder(object): "Used to help build up a new SyncResult for a user" - def __init__(self, sync_config, full_state, since_token, now_token): + + __slots__ = ( + "sync_config", "full_state", "batch_token", "since_token", "pagination_state", + "now_token", "presence", "account_data", "joined", "invited", "archived", + "pagination_info", + ) + + def __init__(self, sync_config, full_state, batch_token, now_token): """ Args: sync_config(SyncConfig) full_state(bool): The full_state flag as specified by user - since_token(StreamToken): The token supplied by user, or None. + batch_token(SyncNextBatchToken): The token supplied by user, or None. now_token(StreamToken): The token to sync up to. """ self.sync_config = sync_config self.full_state = full_state - self.since_token = since_token + self.batch_token = batch_token + self.since_token = batch_token.stream_token if batch_token else None + self.pagination_state = batch_token.pagination_state if batch_token else None self.now_token = now_token self.presence = [] @@ -1104,11 +1298,19 @@ class SyncResultBuilder(object): self.invited = [] self.archived = [] + self.pagination_info = {} + class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. """ + + __slots__ = ( + "room_id", "rtype", "events", "newly_joined", "full_state", "since_token", + "upto_token", "always_include", + ) + def __init__(self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token): """ @@ -1129,3 +1331,4 @@ class RoomSyncResultBuilder(object): self.full_state = full_state self.since_token = since_token self.upto_token = upto_token + self.always_include = False diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 43d8e0bf39..24587cbc61 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -16,10 +16,11 @@ from twisted.internet import defer from synapse.http.servlet import ( - RestServlet, parse_string, parse_integer, parse_boolean + RestServlet, parse_string, parse_integer, parse_boolean, + parse_json_object_from_request, ) -from synapse.handlers.sync import SyncConfig -from synapse.types import StreamToken +from synapse.handlers.sync import SyncConfig, SyncPaginationConfig +from synapse.types import SyncNextBatchToken from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_room_id, ) @@ -85,6 +86,90 @@ class SyncRestServlet(RestServlet): self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req( + request, allow_guest=True + ) + user = requester.user + + body = parse_json_object_from_request(request) + + timeout = body.get("timeout", 0) + + since = body.get("since", None) + + extras = body.get("extras", None) + + if "from" in body: + # /events used to use 'from', but /sync uses 'since'. + # Lets be helpful and whine if we see a 'from'. + raise SynapseError( + 400, "'from' is not a valid parameter. Did you mean 'since'?" + ) + + set_presence = body.get("set_presence", "online") + if set_presence not in self.ALLOWED_PRESENCE: + message = "Parameter 'set_presence' must be one of [%s]" % ( + ", ".join(repr(v) for v in self.ALLOWED_PRESENCE) + ) + raise SynapseError(400, message) + + full_state = body.get("full_state", False) + + filter_id = body.get("filter_id", None) + filter_dict = body.get("filter", None) + pagination_config = body.get("pagination_config", None) + + if filter_dict is not None and filter_id is not None: + raise SynapseError( + 400, + "Can only specify one of `filter` and `filter_id` paramters" + ) + + if filter_id: + filter_collection = yield self.filtering.get_user_filter( + user.localpart, filter_id + ) + filter_key = filter_id + elif filter_dict: + self.filtering.check_valid_filter(filter_dict) + filter_collection = FilterCollection(filter_dict) + filter_key = json.dumps(filter_dict) + else: + filter_collection = DEFAULT_FILTER_COLLECTION + filter_key = None + + request_key = (user, timeout, since, filter_key, full_state) + + sync_config = SyncConfig( + user=user, + filter_collection=filter_collection, + is_guest=requester.is_guest, + request_key=request_key, + pagination_config=SyncPaginationConfig( + order=pagination_config["order"], + limit=pagination_config["limit"], + ) if pagination_config else None, + ) + + if since is not None: + batch_token = SyncNextBatchToken.from_string(since) + else: + batch_token = None + + sync_result = yield self._handle_sync( + requester=requester, + sync_config=sync_config, + batch_token=batch_token, + set_presence=set_presence, + full_state=full_state, + timeout=timeout, + extras=extras, + ) + + defer.returnValue(sync_result) + + @defer.inlineCallbacks def on_GET(self, request): if "from" in request.args: # /events used to use 'from', but /sync uses 'since'. @@ -136,15 +221,32 @@ class SyncRestServlet(RestServlet): filter_collection=filter, is_guest=requester.is_guest, request_key=request_key, + pagination_config=None, ) if since is not None: - since_token = StreamToken.from_string(since) + batch_token = SyncNextBatchToken.from_string(since) else: - since_token = None + batch_token = None + + sync_result = yield self._handle_sync( + requester=requester, + sync_config=sync_config, + batch_token=batch_token, + set_presence=set_presence, + full_state=full_state, + timeout=timeout, + ) + + defer.returnValue(sync_result) + @defer.inlineCallbacks + def _handle_sync(self, requester, sync_config, batch_token, set_presence, + full_state, timeout, extras=None): affect_presence = set_presence != PresenceState.OFFLINE + user = sync_config.user + if affect_presence: yield self.presence_handler.set_state(user, {"presence": set_presence}) @@ -153,8 +255,8 @@ class SyncRestServlet(RestServlet): ) with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( - sync_config, since_token=since_token, timeout=timeout, - full_state=full_state + sync_config, batch_token=batch_token, timeout=timeout, + full_state=full_state, extras=extras, ) time_now = self.clock.time_msec() @@ -184,6 +286,9 @@ class SyncRestServlet(RestServlet): "next_batch": sync_result.next_batch.to_string(), } + if sync_result.pagination_info: + response_content["pagination_info"] = sync_result.pagination_info + defer.returnValue((200, response_content)) def encode_presence(self, events, time_now): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b9ad965fd6..ada20706dc 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -525,6 +525,26 @@ class StreamStore(SQLBaseStore): int(stream), ) + def get_last_event_id_ts_for_room(self, room_id, token): + stream_ordering = RoomStreamToken.parse_stream_token(token).stream + + sql = ( + "SELECT event_id, origin_server_ts FROM events" + " WHERE room_id = ? AND stream_ordering <= ?" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT 1" + ) + + def f(txn): + txn.execute(sql, (room_id, stream_ordering)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + else: + return None + + return self.runInteraction("get_last_event_id_ts_for_room", f) + @defer.inlineCallbacks def get_events_around(self, room_id, event_id, before_limit, after_limit): """Retrieve events and pagination tokens around a given event in a diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 4f089bfb94..49be3c222a 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.types import StreamToken +from synapse.types import StreamToken, SyncNextBatchToken import logging @@ -72,14 +72,18 @@ class PaginationConfig(object): if direction not in ['f', 'b']: raise SynapseError(400, "'dir' parameter is invalid.") - from_tok = get_param("from") + raw_from_tok = get_param("from") to_tok = get_param("to") try: - if from_tok == "END": + from_tok = None + if raw_from_tok == "END": from_tok = None # For backwards compat. - elif from_tok: - from_tok = StreamToken.from_string(from_tok) + elif raw_from_tok: + try: + from_tok = SyncNextBatchToken.from_string(raw_from_tok).stream_token + except: + from_tok = StreamToken.from_string(raw_from_tok) except: raise SynapseError(400, "'from' paramater is invalid") diff --git a/synapse/types.py b/synapse/types.py index f639651a73..13cdc737fb 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -17,6 +17,9 @@ from synapse.api.errors import SynapseError from collections import namedtuple +from unpaddedbase64 import encode_base64, decode_base64 +import ujson as json + Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) @@ -115,8 +118,63 @@ class EventID(DomainSpecificString): SIGIL = "$" +class SyncNextBatchToken( + namedtuple("SyncNextBatchToken", ( + "stream_token", + "pagination_state", + )) +): + @classmethod + def from_string(cls, string): + try: + d = json.loads(decode_base64(string)) + pa = d.get("pa", None) + if pa: + pa = SyncPaginationState.from_dict(pa) + return cls( + stream_token=StreamToken.from_arr(d["t"]), + pagination_state=pa, + ) + except: + raise SynapseError(400, "Invalid Token") + + def to_string(self): + return encode_base64(json.dumps({ + "t": self.stream_token.to_arr(), + "pa": self.pagination_state.to_dict() if self.pagination_state else None, + })) + + def replace(self, **kwargs): + return self._replace(**kwargs) + + +class SyncPaginationState( + namedtuple("SyncPaginationState", ( + "order", + "value", + "limit", + )) +): + @classmethod + def from_dict(cls, d): + try: + return cls(d["o"], d["v"], d["l"]) + except: + raise SynapseError(400, "Invalid Token") + + def to_dict(self): + return { + "o": self.order, + "v": self.value, + "l": self.limit, + } + + def replace(self, **kwargs): + return self._replace(**kwargs) + + class StreamToken( - namedtuple("Token", ( + namedtuple("StreamToken", ( "room_key", "presence_key", "typing_key", @@ -141,6 +199,20 @@ class StreamToken( def to_string(self): return self._SEPARATOR.join([str(k) for k in self]) + @classmethod + def from_arr(cls, arr): + try: + keys = arr + while len(keys) < len(cls._fields): + # i.e. old token from before receipt_key + keys.append("0") + return cls(*keys) + except: + raise SynapseError(400, "Invalid Token") + + def to_arr(self): + return self + @property def room_stream_id(self): # TODO(markjh): Awful hack to work around hacks in the presence tests |