summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6298.misc1
-rw-r--r--changelog.d/6301.feature1
-rw-r--r--synapse/api/constants.py7
-rw-r--r--synapse/api/filtering.py15
-rw-r--r--synapse/events/snapshot.py107
-rw-r--r--synapse/handlers/federation.py38
-rw-r--r--synapse/rest/client/versions.py3
-rw-r--r--synapse/storage/data_stores/main/events.py36
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_labels.sql30
-rw-r--r--synapse/storage/data_stores/main/stream.py11
-rw-r--r--tests/api/test_filtering.py43
-rw-r--r--tests/rest/client/v1/test_rooms.py101
-rw-r--r--tests/rest/client/v1/utils.py15
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py143
-rw-r--r--tests/test_federation.py4
15 files changed, 461 insertions, 94 deletions
diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc
new file mode 100644

index 0000000000..d4190730b2 --- /dev/null +++ b/changelog.d/6298.misc
@@ -0,0 +1 @@ +Refactor EventContext for clarity. \ No newline at end of file diff --git a/changelog.d/6301.feature b/changelog.d/6301.feature new file mode 100644
index 0000000000..78a187a1dc --- /dev/null +++ b/changelog.d/6301.feature
@@ -0,0 +1 @@ +Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 50ddff2934..f4f1ac27c0 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py
@@ -141,3 +141,10 @@ class LimitBlockingTypes(object): MONTHLY_ACTIVE_USER = "monthly_active_user" HS_DISABLED = "hs_disabled" + + +class EventContentFields(object): + """Fields found in events' content, regardless of type.""" + + # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 + LABELS = "org.matrix.labels" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 9f06556bd2..bec13f08d8 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py
@@ -20,6 +20,7 @@ from jsonschema import FormatChecker from twisted.internet import defer +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import RoomID, UserID @@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + # Include or exclude events with the provided labels. + # cf https://github.com/matrix-org/matrix-doc/pull/2326 + "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, + "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, }, } @@ -259,6 +264,9 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + self.labels = self.filter_json.get("org.matrix.labels", None) + self.not_labels = self.filter_json.get("org.matrix.not_labels", []) + def filters_all_types(self): return "*" in self.not_types @@ -282,6 +290,7 @@ class Filter(object): room_id = None ev_type = "m.presence" contains_url = False + labels = [] else: sender = event.get("sender", None) if not sender: @@ -300,10 +309,11 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) + labels = content.get(EventContentFields.LABELS, []) - return self.check_fields(room_id, sender, ev_type, contains_url) + return self.check_fields(room_id, sender, ev_type, labels, contains_url) - def check_fields(self, room_id, sender, event_type, contains_url): + def check_fields(self, room_id, sender, event_type, labels, contains_url): """Checks whether the filter matches the given event fields. Returns: @@ -313,6 +323,7 @@ class Filter(object): "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(event_type, v), + "labels": lambda v: v in labels, } for name, match_func in literal_keys.items(): diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 27cd8a63ff..a269de5482 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -37,9 +37,6 @@ class EventContext: delta_ids (dict[(str, str), str]): Delta from ``prev_group``. (type, state_key) -> event_id. ``None`` for an outlier. - prev_state_events (?): XXX: is this ever set to anything other than - the empty list? - app_service: FIXME _current_state_ids (dict[(str, str), str]|None): @@ -51,36 +48,16 @@ class EventContext: The current state map excluding the current event. None if outlier or we haven't fetched the state from DB yet. (type, state_key) -> event_id - - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet - - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. - - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. - - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. """ state_group = attr.ib(default=None) rejected = attr.ib(default=False) prev_group = attr.ib(default=None) delta_ids = attr.ib(default=None) - prev_state_events = attr.ib(default=attr.Factory(list)) app_service = attr.ib(default=None) - _current_state_ids = attr.ib(default=None) _prev_state_ids = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _current_state_ids = attr.ib(default=None) @staticmethod def with_state( @@ -90,7 +67,6 @@ class EventContext: current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, state_group=state_group, - fetching_state_deferred=defer.succeed(None), prev_group=prev_group, delta_ids=delta_ids, ) @@ -125,7 +101,6 @@ class EventContext: "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None, } @@ -141,7 +116,7 @@ class EventContext: Returns: EventContext """ - context = EventContext( + context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. prev_state_id=input["prev_state_id"], @@ -151,7 +126,6 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], - prev_state_events=input["prev_state_events"], ) app_service_id = input["app_service_id"] @@ -170,14 +144,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._current_state_ids @defer.inlineCallbacks @@ -190,14 +157,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._prev_state_ids def get_cached_current_state_ids(self): @@ -211,6 +171,44 @@ class EventContext: return self._current_state_ids + def _ensure_fetched(self, store): + return defer.succeed(None) + + +@attr.s(slots=True) +class _AsyncEventContextImpl(EventContext): + """ + An implementation of EventContext which fetches _current_state_ids and + _prev_state_ids from the database on demand. + + Attributes: + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self, store): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store + ) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids @@ -228,27 +226,6 @@ class EventContext: else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 04eebbd51e..8b8d907105 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( make_deferred_yieldable, @@ -1878,14 +1879,7 @@ class FederationHandler(BaseHandler): if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warning( - "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg - ) - - context.rejected = RejectedReason.AUTH_ERROR + context = yield self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: yield self._check_for_soft_fail(event, state, backfilled) @@ -2054,12 +2048,12 @@ class FederationHandler(BaseHandler): Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) try: - yield self._update_auth_events_and_context_for_auth( + context = yield self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2077,7 +2071,9 @@ class FederationHandler(BaseHandler): event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) - raise e + context.rejected = RejectedReason.AUTH_ERROR + + return context @defer.inlineCallbacks def _update_auth_events_and_context_for_auth( @@ -2101,7 +2097,7 @@ class FederationHandler(BaseHandler): auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) @@ -2140,7 +2136,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] @@ -2181,7 +2177,7 @@ class FederationHandler(BaseHandler): if event.internal_metadata.is_outlier(): logger.info("Skipping auth_event fetch for outlier") - return + return context # FIXME: Assumes we have and stored all the state for all the # prev_events @@ -2190,7 +2186,7 @@ class FederationHandler(BaseHandler): ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2237,10 +2233,12 @@ class FederationHandler(BaseHandler): auth_events.update(new_state) - yield self._update_context_for_auth_events( + context = yield self._update_context_for_auth_events( event, context, auth_events, event_key ) + return context + @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2249,14 +2247,16 @@ class FederationHandler(BaseHandler): Args: event (Event): The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context (synapse.events.snapshot.EventContext): initial event context auth_events (dict[(str, str)->str]): Events to update in the event context. event_key ((str, str)): (type, state_key) for the current event. this will not be included in the current_state in the context. + + Returns: + Deferred[EventContext]: new event context """ state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key @@ -2281,7 +2281,7 @@ class FederationHandler(BaseHandler): current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 1044ae7b4e..bb30ce3f34 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet): "m.require_identity_server": False, # as per MSC2290 "m.separate_add_and_bind": True, + # Implements support for label-based filtering as described in + # MSC2326. + "org.matrix.label_based_filtering": True, }, }, ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 1f140b553a..b332a42d82 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py
@@ -29,7 +29,7 @@ from prometheus_client import Counter from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 @@ -935,6 +935,13 @@ class EventsStore( self._handle_event_relations(txn, event) + # Store the labels for this event. + labels = event.content.get(EventContentFields.LABELS) + if labels: + self.insert_labels_for_event_txn( + txn, event.event_id, labels, event.room_id, event.depth + ) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1920,6 +1927,33 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. + + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. + """ + return self._simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql new file mode 100644
index 0000000000..5e29c1da19 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- room_id and topoligical_ordering are denormalised from the events table in order to +-- make the index work. +CREATE TABLE IF NOT EXISTS event_labels ( + event_id TEXT, + label TEXT, + room_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + PRIMARY KEY(event_id, label) +); + + +-- This index enables an event pagination looking for a particular label to index the +-- event_labels table first, which is much quicker than scanning the events table and then +-- filtering by label, if the label is rarely used relative to the size of the room. +CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 263999dfca..616ef91d4e 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py
@@ -229,6 +229,14 @@ def filter_to_clause(event_filter): clauses.append("contains_url = ?") args.append(event_filter.contains_url) + # We're only applying the "labels" filter on the database query, because applying the + # "not_labels" filter via a SQL query is non-trivial. Instead, we let + # event_filter.check_fields apply it, which is not as efficient but makes the + # implementation simpler. + if event_filter.labels: + clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels)) + args.extend(event_filter.labels) + return " AND ".join(clauses), args @@ -864,8 +872,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.append(int(limit)) sql = ( - "SELECT event_id, topological_ordering, stream_ordering" + "SELECT DISTINCT event_id, topological_ordering, stream_ordering" " FROM events" + " LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s LIMIT ?" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6ba623de13..2dc5052249 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py
@@ -19,6 +19,7 @@ import jsonschema from twisted.internet import defer +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.room.message"], "not_rooms": ["!726s6s6q:example.com"], "not_senders": ["@spam:example.com"], + "org.matrix.labels": ["#fun"], + "org.matrix.not_labels": ["#work"], }, "ephemeral": { "types": ["m.receipt", "m.typing"], @@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase): ) self.assertFalse(Filter(definition).check(event)) + def test_filter_labels(self): + definition = {"org.matrix.labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#fun"]}, + ) + + self.assertTrue(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#notfun"]}, + ) + + self.assertFalse(Filter(definition).check(event)) + + def test_filter_not_labels(self): + definition = {"org.matrix.not_labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#fun"]}, + ) + + self.assertFalse(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#notfun"]}, + ) + + self.assertTrue(Filter(definition).check(event)) + @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f2ca74611..5e38fd6ced 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.rest.client.v1 import login, profile, room from tests import unittest @@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + def test_filter_labels(self): + """Test that we can filter by a label.""" + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + ) + + events = self._test_filter_labels(message_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + ) + + events = self._test_filter_labels(message_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) + + events = self._test_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_filter_labels(self, message_filter): + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + ) + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&filter=%s" + % (self.room_id, token, message_filter), + ) + self.render(request) + + return channel.json_body["chunk"] + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index cdded88b7f..8ea0cb05ea 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -106,13 +106,22 @@ class RestHelper(object): self.auth_user_id = temp_id def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) if body is None: body = "body_text_here" - path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, "m.room.message", content, txn_id, tok, expect_code + ) + + def send_event( + self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) if tok: path = path + "?access_token=%s" % tok diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094bd..3283c0e47b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from mock import Mock import synapse.rest.admin +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -26,7 +28,12 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" - servlets = [sync.register_servlets] + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def make_homeserver(self, reactor, clock): @@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase): ) +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + """Test that we can filter by a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + request, channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + + class SyncTypingTests(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/test_federation.py b/tests/test_federation.py
index d1acb16f30..7d82b58466 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase): ) self.handler = self.homeserver.get_handlers().federation_handler - self.handler.do_auth = lambda *a, **b: succeed(True) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context + ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( pdus