summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/_base.py8
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/push/action_generator.py6
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py27
-rw-r--r--synapse/rest/client/v2_alpha/register.py5
-rw-r--r--synapse/storage/event_push_actions.py4
-rw-r--r--synapse/storage/registration.py40
-rw-r--r--synapse/storage/schema/delta/28/event_push_actions.sql (renamed from synapse/storage/schema/delta/27/event_push_actions.sql)0
9 files changed, 69 insertions, 31 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 3115a5065d..66e35de6e4 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -23,8 +23,6 @@ from synapse.push.action_generator import ActionGenerator
 
 from synapse.util.logcontext import PreserveLoggingContext
 
-from synapse.events.utils import serialize_event
-
 import logging
 
 
@@ -256,9 +254,9 @@ class BaseHandler(object):
         )
 
         action_generator = ActionGenerator(self.store)
-        yield action_generator.handle_push_actions_for_event(serialize_event(
-            event, self.clock.time_msec()
-        ))
+        yield action_generator.handle_push_actions_for_event(
+            event, self
+        )
 
         destinations = set(extra_destinations)
         for k, s in context.current_state.items():
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 764709b424..075b9e21c3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -32,7 +32,7 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID
 
-from synapse.events.utils import prune_event, serialize_event
+from synapse.events.utils import prune_event
 
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -246,8 +246,8 @@ class FederationHandler(BaseHandler):
 
         if not backfilled and not event.internal_metadata.is_outlier():
             action_generator = ActionGenerator(self.store)
-            yield action_generator.handle_push_actions_for_event(serialize_event(
-                event, self.clock.time_msec())
+            yield action_generator.handle_push_actions_for_event(
+                event, self
             )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6f111ff63e..1799a668c6 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -84,7 +84,8 @@ class RegistrationHandler(BaseHandler):
         localpart=None,
         password=None,
         generate_token=True,
-        guest_access_token=None
+        guest_access_token=None,
+        make_guest=False
     ):
         """Registers a new client on the server.
 
@@ -118,6 +119,7 @@ class RegistrationHandler(BaseHandler):
                 token=token,
                 password_hash=password_hash,
                 was_guest=guest_access_token is not None,
+                make_guest=make_guest
             )
 
             yield registered_user(self.distributor, user)
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 5526324a6d..bcd40798f9 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -33,12 +33,12 @@ class ActionGenerator:
         # tag (ie. we just need all the users).
 
     @defer.inlineCallbacks
-    def handle_push_actions_for_event(self, event):
+    def handle_push_actions_for_event(self, event, handler):
         bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
-            event['room_id'], self.store
+            event.room_id, self.store
         )
 
-        actions_by_user = bulk_evaluator.action_for_event_by_user(event)
+        actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
 
         yield self.store.set_push_actions_for_event_and_users(
             event,
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c00acfd87e..63d65b4465 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -23,6 +23,8 @@ from synapse.types import UserID
 import baserules
 from push_rule_evaluator import PushRuleEvaluator
 
+from synapse.events.utils import serialize_event
+
 logger = logging.getLogger(__name__)
 
 
@@ -54,7 +56,7 @@ def evaluator_for_room_id(room_id, store):
             display_names[ev.state_key] = ev.content.get("displayname")
 
     defer.returnValue(BulkPushRuleEvaluator(
-        room_id, rules_by_user, display_names, users
+        room_id, rules_by_user, display_names, users, store
     ))
 
 
@@ -67,13 +69,15 @@ class BulkPushRuleEvaluator:
     the same logic to run the actual rules, but could be optimised further
     (see https://matrix.org/jira/browse/SYN-562)
     """
-    def __init__(self, room_id, rules_by_user, display_names, users_in_room):
+    def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
         self.room_id = room_id
         self.rules_by_user = rules_by_user
         self.display_names = display_names
         self.users_in_room = users_in_room
+        self.store = store
 
-    def action_for_event_by_user(self, event):
+    @defer.inlineCallbacks
+    def action_for_event_by_user(self, event, handler):
         actions_by_user = {}
 
         for uid, rules in self.rules_by_user.items():
@@ -81,6 +85,13 @@ class BulkPushRuleEvaluator:
             if uid in self.display_names:
                 display_name = self.display_names[uid]
 
+            is_guest = yield self.store.is_guest(UserID.from_string(uid))
+            filtered = yield handler._filter_events_for_client(
+                uid, [event], is_guest=is_guest
+            )
+            if len(filtered) == 0:
+                continue
+
             for rule in rules:
                 if 'enabled' in rule and not rule['enabled']:
                     continue
@@ -94,14 +105,20 @@ class BulkPushRuleEvaluator:
                     if len(actions) > 0:
                         actions_by_user[uid] = actions
                     break
-        return actions_by_user
+        defer.returnValue(actions_by_user)
 
     @staticmethod
     def event_matches_rule(event, rule,
                            display_name, room_member_count, profile_tag):
         matches = True
+
+        # passing the clock all the way into here is extremely awkward and push
+        # rules do not care about any of the relative timestamps, so we just
+        # pass 0 for the current time.
+        client_event = serialize_event(event, 0)
+
         for cond in rule['conditions']:
             matches &= PushRuleEvaluator._event_fulfills_condition(
-                event, cond, display_name, room_member_count, profile_tag
+                client_event, cond, display_name, room_member_count, profile_tag
             )
         return matches
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 25389ceded..c4d025b465 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -259,7 +259,10 @@ class RegisterRestServlet(RestServlet):
     def _do_guest_registration(self):
         if not self.hs.config.allow_guest_access:
             defer.returnValue((403, "Guest access is disabled"))
-        user_id, _ = yield self.registration_handler.register(generate_token=False)
+        user_id, _ = yield self.registration_handler.register(
+            generate_token=False,
+            make_guest=True
+        )
         access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
         defer.returnValue((200, {
             "user_id": user_id,
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 3075d02257..0634af6b62 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -32,8 +32,8 @@ class EventPushActionsStore(SQLBaseStore):
         values = []
         for uid, profile_tag, actions in tuples:
             values.append({
-                'room_id': event['room_id'],
-                'event_id': event['event_id'],
+                'room_id': event.room_id,
+                'event_id': event.event_id,
                 'user_id': uid,
                 'profile_tag': profile_tag,
                 'actions': json.dumps(actions)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index f0fa0bd33c..c79066f774 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import StoreError, Codes
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 
 class RegistrationStore(SQLBaseStore):
@@ -73,7 +73,8 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash, was_guest=False):
+    def register(self, user_id, token, password_hash,
+                 was_guest=False, make_guest=False):
         """Attempts to register an account.
 
         Args:
@@ -82,15 +83,18 @@ class RegistrationStore(SQLBaseStore):
             password_hash (str): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
+            make_guest (boolean): True if the the new user should be guest,
+                false to add a regular user account.
         Raises:
             StoreError if the user_id could not be registered.
         """
         yield self.runInteraction(
             "register",
-            self._register, user_id, token, password_hash, was_guest
+            self._register, user_id, token, password_hash, was_guest, make_guest
         )
+        self.is_guest.invalidate((user_id,))
 
-    def _register(self, txn, user_id, token, password_hash, was_guest):
+    def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
         now = int(self.clock.time())
 
         next_id = self._access_tokens_id_gen.get_next_txn(txn)
@@ -100,12 +104,14 @@ class RegistrationStore(SQLBaseStore):
                 txn.execute("UPDATE users SET"
                             " password_hash = ?,"
                             " upgrade_ts = ?"
+                            " is_guest = ?"
                             " WHERE name = ?",
-                            [password_hash, now, user_id])
+                            [password_hash, now, make_guest, user_id])
             else:
-                txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
-                            "VALUES (?,?,?)",
-                            [user_id, password_hash, now])
+                txn.execute("INSERT INTO users "
+                            "(name, password_hash, creation_ts, is_guest) "
+                            "VALUES (?,?,?,?)",
+                            [user_id, password_hash, now, make_guest])
         except self.database_engine.module.IntegrityError:
             raise StoreError(
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -126,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
             keyvalues={
                 "name": user_id,
             },
-            retcols=["name", "password_hash"],
+            retcols=["name", "password_hash", "is_guest"],
             allow_none=True,
         )
 
@@ -136,7 +142,7 @@ class RegistrationStore(SQLBaseStore):
         """
         def f(txn):
             sql = (
-                "SELECT name, password_hash FROM users"
+                "SELECT name, password_hash, is_guest FROM users"
                 " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
@@ -249,9 +255,21 @@ class RegistrationStore(SQLBaseStore):
 
         defer.returnValue(res if res else False)
 
+    @cachedInlineCallbacks()
+    def is_guest(self, user):
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        defer.returnValue(res if res else False)
+
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, access_tokens.id as token_id"
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
diff --git a/synapse/storage/schema/delta/27/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql
index bdf6ae3f24..bdf6ae3f24 100644
--- a/synapse/storage/schema/delta/27/event_push_actions.sql
+++ b/synapse/storage/schema/delta/28/event_push_actions.sql