summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/filtering.py41
-rw-r--r--synapse/rest/client/v2_alpha/filter.py25
-rw-r--r--synapse/server.py5
3 files changed, 58 insertions, 13 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
new file mode 100644
index 0000000000..922c40004c
--- /dev/null
+++ b/synapse/api/filtering.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# TODO(paul)
+_filters_for_user = {}
+
+
+class Filtering(object):
+
+    def __init__(self, hs):
+        super(Filtering, self).__init__()
+        self.hs = hs
+
+    def get_user_filter(self, user_localpart, filter_id):
+        filters = _filters_for_user.get(user_localpart, None)
+
+        if not filters or filter_id >= len(filters):
+            raise KeyError()
+
+        return filters[filter_id]
+
+    def add_user_filter(self, user_localpart, definition):
+        filters = _filters_for_user.setdefault(user_localpart, [])
+
+        filter_id = len(filters)
+        filters.append(definition)
+
+        return filter_id
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index a9a180ec04..585c8e02e8 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,10 +28,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-# TODO(paul)
-_filters_for_user = {}
-
-
 class GetFilterRestServlet(RestServlet):
     PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
 
@@ -39,6 +35,7 @@ class GetFilterRestServlet(RestServlet):
         super(GetFilterRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
+        self.filtering = hs.get_filtering()
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, filter_id):
@@ -56,13 +53,14 @@ class GetFilterRestServlet(RestServlet):
         except:
             raise SynapseError(400, "Invalid filter_id")
 
-        filters = _filters_for_user.get(target_user.localpart, None)
-
-        if not filters or filter_id >= len(filters):
+        try:
+            defer.returnValue((200, self.filtering.get_user_filter(
+                user_localpart=target_user.localpart,
+                filter_id=filter_id,
+            )))
+        except KeyError:
             raise SynapseError(400, "No such filter")
 
-        defer.returnValue((200, filters[filter_id]))
-
 
 class CreateFilterRestServlet(RestServlet):
     PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
@@ -71,6 +69,7 @@ class CreateFilterRestServlet(RestServlet):
         super(CreateFilterRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
+        self.filtering = hs.get_filtering()
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id):
@@ -90,10 +89,10 @@ class CreateFilterRestServlet(RestServlet):
         except:
             raise SynapseError(400, "Invalid filter definition")
 
-        filters = _filters_for_user.setdefault(target_user.localpart, [])
-
-        filter_id = len(filters)
-        filters.append(content)
+        filter_id = self.filtering.add_user_filter(
+            user_localpart=target_user.localpart,
+            definition=content,
+        )
 
         defer.returnValue((200, {"filter_id": str(filter_id)}))
 
diff --git a/synapse/server.py b/synapse/server.py
index f09d5d581e..9b42079e05 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -32,6 +32,7 @@ from synapse.streams.events import EventSources
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering
 
 
 class BaseHomeServer(object):
@@ -79,6 +80,7 @@ class BaseHomeServer(object):
         'ratelimiter',
         'keyring',
         'event_builder_factory',
+        'filtering',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer):
             clock=self.get_clock(),
             hostname=self.hostname,
         )
+
+    def build_filtering(self):
+        return Filtering(self)