diff options
-rw-r--r-- | synapse/api/filtering.py | 41 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/filter.py | 25 | ||||
-rw-r--r-- | synapse/server.py | 5 |
3 files changed, 58 insertions, 13 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py new file mode 100644 index 0000000000..922c40004c --- /dev/null +++ b/synapse/api/filtering.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# TODO(paul) +_filters_for_user = {} + + +class Filtering(object): + + def __init__(self, hs): + super(Filtering, self).__init__() + self.hs = hs + + def get_user_filter(self, user_localpart, filter_id): + filters = _filters_for_user.get(user_localpart, None) + + if not filters or filter_id >= len(filters): + raise KeyError() + + return filters[filter_id] + + def add_user_filter(self, user_localpart, definition): + filters = _filters_for_user.setdefault(user_localpart, []) + + filter_id = len(filters) + filters.append(definition) + + return filter_id diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index a9a180ec04..585c8e02e8 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -28,10 +28,6 @@ import logging logger = logging.getLogger(__name__) -# TODO(paul) -_filters_for_user = {} - - class GetFilterRestServlet(RestServlet): PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") @@ -39,6 +35,7 @@ class GetFilterRestServlet(RestServlet): super(GetFilterRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() + self.filtering = hs.get_filtering() @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): @@ -56,13 +53,14 @@ class GetFilterRestServlet(RestServlet): except: raise SynapseError(400, "Invalid filter_id") - filters = _filters_for_user.get(target_user.localpart, None) - - if not filters or filter_id >= len(filters): + try: + defer.returnValue((200, self.filtering.get_user_filter( + user_localpart=target_user.localpart, + filter_id=filter_id, + ))) + except KeyError: raise SynapseError(400, "No such filter") - defer.returnValue((200, filters[filter_id])) - class CreateFilterRestServlet(RestServlet): PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") @@ -71,6 +69,7 @@ class CreateFilterRestServlet(RestServlet): super(CreateFilterRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() + self.filtering = hs.get_filtering() @defer.inlineCallbacks def on_POST(self, request, user_id): @@ -90,10 +89,10 @@ class CreateFilterRestServlet(RestServlet): except: raise SynapseError(400, "Invalid filter definition") - filters = _filters_for_user.setdefault(target_user.localpart, []) - - filter_id = len(filters) - filters.append(content) + filter_id = self.filtering.add_user_filter( + user_localpart=target_user.localpart, + definition=content, + ) defer.returnValue((200, {"filter_id": str(filter_id)})) diff --git a/synapse/server.py b/synapse/server.py index f09d5d581e..9b42079e05 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -32,6 +32,7 @@ from synapse.streams.events import EventSources from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory +from synapse.api.filtering import Filtering class BaseHomeServer(object): @@ -79,6 +80,7 @@ class BaseHomeServer(object): 'ratelimiter', 'keyring', 'event_builder_factory', + 'filtering', ] def __init__(self, hostname, **kwargs): @@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer): clock=self.get_clock(), hostname=self.hostname, ) + + def build_filtering(self): + return Filtering(self) |