diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 922c40004c..014e2e1fc9 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
# TODO(paul)
_filters_for_user = {}
@@ -24,18 +26,28 @@ class Filtering(object):
super(Filtering, self).__init__()
self.hs = hs
+ @defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters):
raise KeyError()
- return filters[filter_id]
+ # trivial yield to make it a generator so d.iC works
+ yield
+ defer.returnValue(filters[filter_id])
+ @defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
- return filter_id
+ # trivial yield, see above
+ yield
+ defer.returnValue(filter_id)
+
+ # TODO(paul): surely we should probably add a delete_user_filter or
+ # replace_user_filter at some point? There's no REST API specified for
+ # them however
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 585c8e02e8..09e44e8ae0 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -54,10 +54,12 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
- defer.returnValue((200, self.filtering.get_user_filter(
+ filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart,
filter_id=filter_id,
- )))
+ )
+
+ defer.returnValue((200, filter))
except KeyError:
raise SynapseError(400, "No such filter")
@@ -89,7 +91,7 @@ class CreateFilterRestServlet(RestServlet):
except:
raise SynapseError(400, "Invalid filter definition")
- filter_id = self.filtering.add_user_filter(
+ filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
definition=content,
)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index c6c5317696..fecadd1056 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -53,14 +53,15 @@ class FilteringTestCase(unittest.TestCase):
self.filtering = hs.get_filtering()
+ @defer.inlineCallbacks
def test_filter(self):
- filter_id = self.filtering.add_user_filter(
+ filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart,
definition={"type": ["m.*"]},
)
self.assertEquals(filter_id, 0)
- filter = self.filtering.get_user_filter(
+ filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
|