diff options
author | Erik Johnston <erik@matrix.org> | 2016-01-25 16:58:39 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2016-01-25 16:58:39 +0000 |
commit | c887c4cbd5366ecb59b739d298e520414ca12dee (patch) | |
tree | c0913b6f147b4adff5f0c8507f4325d61e321205 | |
parent | Merge pull request #525 from matrix-org/erikj/select_many (diff) | |
parent | PEP8 (diff) | |
download | synapse-c887c4cbd5366ecb59b739d298e520414ca12dee.tar.xz |
Merge pull request #524 from matrix-org/erikj/sync
Move some sync logic from rest to handlers pacakege
Diffstat (limited to '')
-rw-r--r-- | synapse/api/filtering.py | 22 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 182 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sync.py | 57 | ||||
-rw-r--r-- | tests/api/test_filtering.py | 10 |
4 files changed, 178 insertions, 93 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 5530b8c48f..6c13ada5df 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -190,18 +190,16 @@ class Filter(object): Returns: bool: True if the event matches """ - if isinstance(event, dict): - return self.check_fields( - event.get("room_id", None), - event.get("sender", None), - event.get("type", None), - ) - else: - return self.check_fields( - getattr(event, "room_id", None), - getattr(event, "sender", None), - event.type, - ) + sender = event.get("sender", None) + if not sender: + # Presence events have their 'sender' in content.user_id + sender = event.get("content", {}).get("user_id", None) + + return self.check_fields( + event.get("room_id", None), + sender, + event.get("type", None), + ) def check_fields(self, room_id, sender, event_type): """Checks whether the filter matches the given event fields. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 53e1eb0508..328c049b03 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) SyncConfig = collections.namedtuple("SyncConfig", [ "user", - "filter", + "filter_collection", "is_guest", ]) @@ -142,8 +142,9 @@ class SyncHandler(BaseHandler): if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user(sync_config, since_token, - full_state=full_state) + result = yield self.current_sync_for_user( + sync_config, since_token, full_state=full_state, + ) defer.returnValue(result) else: def current_sync_callback(before_token, after_token): @@ -151,7 +152,7 @@ class SyncHandler(BaseHandler): result = yield self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, - from_token=since_token + from_token=since_token, ) defer.returnValue(result) @@ -205,7 +206,7 @@ class SyncHandler(BaseHandler): ) membership_list = (Membership.INVITE, Membership.JOIN) - if sync_config.filter.include_leave: + if sync_config.filter_collection.include_leave: membership_list += (Membership.LEAVE, Membership.BAN) room_list = yield self.store.get_rooms_for_user_where_membership_is( @@ -266,9 +267,17 @@ class SyncHandler(BaseHandler): deferreds, consumeErrors=True ).addErrback(unwrapFirstError) + account_data_for_user = sync_config.filter_collection.filter_account_data( + self.account_data_for_user(account_data) + ) + + presence = sync_config.filter_collection.filter_presence( + presence + ) + defer.returnValue(SyncResult( presence=presence, - account_data=self.account_data_for_user(account_data), + account_data=account_data_for_user, joined=joined, invited=invited, archived=archived, @@ -302,14 +311,31 @@ class SyncHandler(BaseHandler): current_state = yield self.get_state_at(room_id, now_token) + current_state = { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + current_state.values() + ) + } + + account_data = self.account_data_for_room( + room_id, tags_by_room, account_data_by_room + ) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral( + ephemeral_by_room.get(room_id, []) + ) + defer.returnValue(JoinedSyncResult( room_id=room_id, timeline=batch, state=current_state, - ephemeral=ephemeral_by_room.get(room_id, []), - account_data=self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ), + ephemeral=ephemeral, + account_data=account_data, unread_notifications=unread_notifications, )) @@ -365,7 +391,7 @@ class SyncHandler(BaseHandler): typing, typing_key = yield typing_source.get_new_events( user=sync_config.user, from_key=typing_key, - limit=sync_config.filter.ephemeral_limit(), + limit=sync_config.filter_collection.ephemeral_limit(), room_ids=room_ids, is_guest=sync_config.is_guest, ) @@ -388,7 +414,7 @@ class SyncHandler(BaseHandler): receipts, receipt_key = yield receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, - limit=sync_config.filter.ephemeral_limit(), + limit=sync_config.filter_collection.ephemeral_limit(), room_ids=room_ids, is_guest=sync_config.is_guest, ) @@ -419,13 +445,26 @@ class SyncHandler(BaseHandler): leave_state = yield self.store.get_state_for_event(leave_event_id) + leave_state = { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + leave_state.values() + ) + } + + account_data = self.account_data_for_room( + room_id, tags_by_room, account_data_by_room + ) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data + ) + defer.returnValue(ArchivedSyncResult( room_id=room_id, timeline=batch, state=leave_state, - account_data=self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ), + account_data=account_data, )) @defer.inlineCallbacks @@ -444,7 +483,7 @@ class SyncHandler(BaseHandler): presence, presence_key = yield presence_source.get_new_events( user=sync_config.user, from_key=since_token.presence_key, - limit=sync_config.filter.presence_limit(), + limit=sync_config.filter_collection.presence_limit(), room_ids=room_ids, is_guest=sync_config.is_guest, ) @@ -473,7 +512,7 @@ class SyncHandler(BaseHandler): sync_config.user ) - timeline_limit = sync_config.filter.timeline_limit() + timeline_limit = sync_config.filter_collection.timeline_limit() room_events, _ = yield self.store.get_room_events_stream( sync_config.user.to_string(), @@ -538,6 +577,27 @@ class SyncHandler(BaseHandler): # the timeline is inherently limited if we've just joined limited = True + recents = sync_config.filter_collection.filter_room_timeline(recents) + + state = { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + state.values() + ) + } + + acc_data = self.account_data_for_room( + room_id, tags_by_room, account_data_by_room + ) + + acc_data = sync_config.filter_collection.filter_room_account_data( + acc_data + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral( + ephemeral_by_room.get(room_id, []) + ) + room_sync = JoinedSyncResult( room_id=room_id, timeline=TimelineBatch( @@ -546,10 +606,8 @@ class SyncHandler(BaseHandler): limited=limited, ), state=state, - ephemeral=ephemeral_by_room.get(room_id, []), - account_data=self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ), + ephemeral=ephemeral, + account_data=acc_data, unread_notifications={}, ) logger.debug("Result for room %s: %r", room_id, room_sync) @@ -603,9 +661,17 @@ class SyncHandler(BaseHandler): for event in invite_events ] + account_data_for_user = sync_config.filter_collection.filter_account_data( + self.account_data_for_user(account_data) + ) + + presence = sync_config.filter_collection.filter_presence( + presence + ) + defer.returnValue(SyncResult( presence=presence, - account_data=self.account_data_for_user(account_data), + account_data=account_data_for_user, joined=joined, invited=invited, archived=archived, @@ -621,7 +687,7 @@ class SyncHandler(BaseHandler): limited = True recents = [] filtering_factor = 2 - timeline_limit = sync_config.filter.timeline_limit() + timeline_limit = sync_config.filter_collection.timeline_limit() load_limit = max(timeline_limit * filtering_factor, 100) max_repeat = 3 # Only try a few times per room, otherwise room_key = now_token.room_key @@ -634,9 +700,9 @@ class SyncHandler(BaseHandler): from_token=since_token.room_key if since_token else None, end_token=end_key, ) - (room_key, _) = keys + room_key, _ = keys end_key = "s" + room_key.split('-')[-1] - loaded_recents = sync_config.filter.filter_room_timeline(events) + loaded_recents = sync_config.filter_collection.filter_room_timeline(events) loaded_recents = yield self._filter_events_for_client( sync_config.user.to_string(), loaded_recents, @@ -684,17 +750,23 @@ class SyncHandler(BaseHandler): logger.debug("Recents %r", batch) - current_state = yield self.get_state_at(room_id, now_token) + if batch.limited: + current_state = yield self.get_state_at(room_id, now_token) - state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token - ) + state_at_previous_sync = yield self.get_state_at( + room_id, stream_position=since_token + ) - state = yield self.compute_state_delta( - since_token=since_token, - previous_state=state_at_previous_sync, - current_state=current_state, - ) + state = yield self.compute_state_delta( + since_token=since_token, + previous_state=state_at_previous_sync, + current_state=current_state, + ) + else: + state = { + (event.type, event.state_key): event + for event in batch.events if event.is_state() + } just_joined = yield self.check_joined_room(sync_config, state) if just_joined: @@ -711,14 +783,29 @@ class SyncHandler(BaseHandler): 1 for notif in notifs if _action_has_highlight(notif["actions"]) ]) + state = { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state(state.values()) + } + + account_data = self.account_data_for_room( + room_id, tags_by_room, account_data_by_room + ) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral( + ephemeral_by_room.get(room_id, []) + ) + room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, state=state, - ephemeral=ephemeral_by_room.get(room_id, []), - account_data=self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ), + ephemeral=ephemeral, + account_data=account_data, unread_notifications=unread_notifications, ) @@ -765,13 +852,26 @@ class SyncHandler(BaseHandler): current_state=state_events_at_leave, ) + state_events_delta = { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + state_events_delta.values() + ) + } + + account_data = self.account_data_for_room( + leave_event.room_id, tags_by_room, account_data_by_room + ) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data + ) + room_sync = ArchivedSyncResult( room_id=leave_event.room_id, timeline=batch, state=state_events_delta, - account_data=self.account_data_for_room( - leave_event.room_id, tags_by_room, account_data_by_room - ), + account_data=account_data, ) logger.debug("Room sync: %r", room_sync) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index ab924ad9e0..07b5b5dfd5 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -130,7 +130,7 @@ class SyncRestServlet(RestServlet): sync_config = SyncConfig( user=user, - filter=filter, + filter_collection=filter, is_guest=requester.is_guest, ) @@ -154,23 +154,21 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, filter, time_now, requester.access_token_id + sync_result.joined, time_now, requester.access_token_id ) invited = self.encode_invited( - sync_result.invited, filter, time_now, requester.access_token_id + sync_result.invited, time_now, requester.access_token_id ) archived = self.encode_archived( - sync_result.archived, filter, time_now, requester.access_token_id + sync_result.archived, time_now, requester.access_token_id ) response_content = { - "account_data": self.encode_account_data( - sync_result.account_data, filter, time_now - ), + "account_data": {"events": sync_result.account_data}, "presence": self.encode_presence( - sync_result.presence, filter, time_now + sync_result.presence, time_now ), "rooms": { "join": joined, @@ -182,24 +180,20 @@ class SyncRestServlet(RestServlet): defer.returnValue((200, response_content)) - def encode_presence(self, events, filter, time_now): + def encode_presence(self, events, time_now): formatted = [] for event in events: event = copy.deepcopy(event) event['sender'] = event['content'].pop('user_id') formatted.append(event) - return {"events": filter.filter_presence(formatted)} - - def encode_account_data(self, events, filter, time_now): - return {"events": filter.filter_account_data(events)} + return {"events": formatted} - def encode_joined(self, rooms, filter, time_now, token_id): + def encode_joined(self, rooms, time_now, token_id): """ Encode the joined rooms in a sync result :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync results for rooms this user is joined to - :param FilterCollection filter: filters to apply to the results :param int time_now: current time - used as a baseline for age calculations :param int token_id: ID of the user's auth token - used for namespacing @@ -211,18 +205,17 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, filter, time_now, token_id + room, time_now, token_id ) return joined - def encode_invited(self, rooms, filter, time_now, token_id): + def encode_invited(self, rooms, time_now, token_id): """ Encode the invited rooms in a sync result :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of sync results for rooms this user is joined to - :param FilterCollection filter: filters to apply to the results :param int time_now: current time - used as a baseline for age calculations :param int token_id: ID of the user's auth token - used for namespacing @@ -237,7 +230,9 @@ class SyncRestServlet(RestServlet): room.invite, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, ) - invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) + unsigned = dict(invite.get("unsigned", {})) + invite["unsigned"] = unsigned + invited_state = list(unsigned.pop("invite_room_state", [])) invited_state.append(invite) invited[room.room_id] = { "invite_state": {"events": invited_state} @@ -245,13 +240,12 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, filter, time_now, token_id): + def encode_archived(self, rooms, time_now, token_id): """ Encode the archived rooms in a sync result :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of sync results for rooms this user is joined to - :param FilterCollection filter: filters to apply to the results :param int time_now: current time - used as a baseline for age calculations :param int token_id: ID of the user's auth token - used for namespacing @@ -263,17 +257,16 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, filter, time_now, token_id, joined=False + room, time_now, token_id, joined=False ) return joined @staticmethod - def encode_room(room, filter, time_now, token_id, joined=True): + def encode_room(room, time_now, token_id, joined=True): """ :param JoinedSyncResult|ArchivedSyncResult room: sync result for a single room - :param FilterCollection filter: filters to apply to the results :param int time_now: current time - used as a baseline for age calculations :param int token_id: ID of the user's auth token - used for namespacing @@ -292,19 +285,17 @@ class SyncRestServlet(RestServlet): ) state_dict = room.state - timeline_events = filter.filter_room_timeline(room.timeline.events) + timeline_events = room.timeline.events state_dict = SyncRestServlet._rollback_state_for_timeline( state_dict, timeline_events) - state_events = filter.filter_room_state(state_dict.values()) + state_events = state_dict.values() serialized_state = [serialize(e) for e in state_events] serialized_timeline = [serialize(e) for e in timeline_events] - account_data = filter.filter_room_account_data( - room.account_data - ) + account_data = room.account_data result = { "timeline": { @@ -317,7 +308,7 @@ class SyncRestServlet(RestServlet): } if joined: - ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) + ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications @@ -334,8 +325,6 @@ class SyncRestServlet(RestServlet): :param list[synapse.events.EventBase] timeline: the event timeline :return: updated state dictionary """ - logger.debug("Processing state dict %r; timeline %r", state, - [e.get_dict() for e in timeline]) result = state.copy() @@ -357,8 +346,8 @@ class SyncRestServlet(RestServlet): # the event graph, and the state is no longer valid. Really, # the event shouldn't be in the timeline. We're going to ignore # it for now, however. - logger.warn("Found state event %r in timeline which doesn't " - "match state dictionary", timeline_event) + logger.debug("Found state event %r in timeline which doesn't " + "match state dictionary", timeline_event) continue prev_event_id = timeline_event.unsigned.get("replaces_state", None) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 16ee6bbe6a..1a4e439d30 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -13,26 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from tests import unittest from twisted.internet import defer -from mock import Mock, NonCallableMock +from mock import Mock from tests.utils import ( MockHttpResource, DeferredMockCallable, setup_test_homeserver ) from synapse.types import UserID -from synapse.api.filtering import FilterCollection, Filter +from synapse.api.filtering import Filter +from synapse.events import FrozenEvent user_localpart = "test_user" # MockEvent = namedtuple("MockEvent", "sender type room_id") def MockEvent(**kwargs): - ev = NonCallableMock(spec_set=kwargs.keys()) - ev.configure_mock(**kwargs) - return ev + return FrozenEvent(kwargs) class FilteringTestCase(unittest.TestCase): |