summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorAndrew Morgan <1342360+anoadragon453@users.noreply.github.com>2019-03-12 14:42:53 +0000
committerGitHub <noreply@github.com>2019-03-12 14:42:53 +0000
commitd42c81d724477803e6b0db6017281a3394a9cee5 (patch)
treefad76481be9834edbd9bc7b5de5496460b6877db /synapse
parentClarify what registration_shared_secret allows for (#2885) (#4844) (diff)
downloadsynapse-d42c81d724477803e6b0db6017281a3394a9cee5.tar.xz
Transfer local user's push rules on room upgrade (#4838)
Transfer push rules (notifications) on room upgrade
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/room_member.py4
-rw-r--r--synapse/storage/push_rule.py57
2 files changed, 61 insertions, 0 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 190ea2c7b1..aead9e4608 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -232,6 +232,10 @@ class RoomMemberHandler(object):
                 self.copy_room_tags_and_direct_to_room(
                     predecessor["room_id"], room_id, user_id,
                 )
+                # Move over old push rules
+                self.store.move_push_rules_from_room_to_room_for_user(
+                    predecessor["room_id"], room_id, user_id,
+                )
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 6a5028961d..4b8438c3e9 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -186,6 +186,63 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
         defer.returnValue(results)
 
     @defer.inlineCallbacks
+    def move_push_rule_from_room_to_room(
+        self, new_room_id, user_id, rule,
+    ):
+        """Move a single push rule from one room to another for a specific user.
+
+        Args:
+            new_room_id (str): ID of the new room.
+            user_id (str): ID of user the push rule belongs to.
+            rule (Dict): A push rule.
+        """
+        # Create new rule id
+        rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
+        new_rule_id = rule_id_scope + "/" + new_room_id
+
+        # Change room id in each condition
+        for condition in rule.get("conditions", []):
+            if condition.get("key") == "room_id":
+                condition["pattern"] = new_room_id
+
+        # Add the rule for the new room
+        yield self.add_push_rule(
+            user_id=user_id,
+            rule_id=new_rule_id,
+            priority_class=rule["priority_class"],
+            conditions=rule["conditions"],
+            actions=rule["actions"],
+        )
+
+        # Delete push rule for the old room
+        yield self.delete_push_rule(user_id, rule["rule_id"])
+
+    @defer.inlineCallbacks
+    def move_push_rules_from_room_to_room_for_user(
+        self, old_room_id, new_room_id, user_id,
+    ):
+        """Move all of the push rules from one room to another for a specific
+        user.
+
+        Args:
+            old_room_id (str): ID of the old room.
+            new_room_id (str): ID of the new room.
+            user_id (str): ID of user to copy push rules for.
+        """
+        # Retrieve push rules for this user
+        user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+        # Get rules relating to the old room, move them to the new room, then
+        # delete them from the old room
+        for rule in user_push_rules:
+            conditions = rule.get("conditions", [])
+            if any((c.get("key") == "room_id" and
+                    c.get("pattern") == old_room_id) for c in conditions):
+                self.move_push_rule_from_room_to_room(
+                    new_room_id, user_id, rule,
+                )
+
+    @defer.inlineCallbacks
     def bulk_get_push_rules_for_room(self, event, context):
         state_group = context.state_group
         if not state_group: