summary refs log tree commit diff
path: root/synapse/storage/push_rule.py
blob: ca04f2ccee6aa131a45ff91144a2d5c9f1dc8131 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from ._base import SQLBaseStore, Table
from twisted.internet import defer

import logging
import copy
import json

logger = logging.getLogger(__name__)


class PushRuleStore(SQLBaseStore):
    """Read, insert and delete rows of the push rules table.

    Rules are ordered by ``priority_class`` and then ``priority``, both
    descending, so a rule with a larger ``priority`` sorts *before* a rule
    with a smaller one inside the same class.
    """

    @defer.inlineCallbacks
    def get_push_rules_for_user_name(self, user_name):
        """Fetch all push rules belonging to ``user_name``.

        Args:
            user_name: value matched against the ``user_name`` column.

        Returns:
            Deferred firing with a list of dicts, one per row, keyed by the
            column names in ``PushRuleTable.fields`` and ordered by
            ``priority_class`` then ``priority``, both descending.
        """
        sql = (
            "SELECT "+",".join(PushRuleTable.fields)+" "
            "FROM "+PushRuleTable.table_name+" "
            "WHERE user_name = ? "
            "ORDER BY priority_class DESC, priority DESC"
        )
        rows = yield self._execute(None, sql, user_name)

        # Columns come back in the order they were listed in the SELECT.
        defer.returnValue(
            [dict(zip(PushRuleTable.fields, row)) for row in rows]
        )

    @defer.inlineCallbacks
    def add_push_rule(self, before, after, **kwargs):
        """Insert (or replace) a push rule.

        Args:
            before: rule_id of an existing rule the new rule should sort
                ahead of, or a falsy value.
            after: rule_id of an existing rule the new rule should sort
                behind, or a falsy value.
            **kwargs: column name -> value for the new row. ``conditions``
                and ``actions`` are JSON-encoded before storage and any
                ``id`` entry is discarded. ``user_name`` is required by both
                transaction functions; ``priority_class`` is required when
                neither ``before`` nor ``after`` is given.

        Returns:
            Deferred firing with whatever ``runInteraction`` returns for the
            chosen transaction function.

        Raises:
            RuleNotFoundException: the ``before``/``after`` rule is unknown.
            InconsistentRuleException: the given ``priority_class`` differs
                from that of the ``before``/``after`` rule.
        """
        vals = copy.copy(kwargs)
        for json_column in ('conditions', 'actions'):
            if json_column in vals:
                vals[json_column] = json.dumps(vals[json_column])
        # The remaining keys are used as column names as-is; an unknown
        # column is left for the database to reject.
        vals.pop('id', None)

        if before or after:
            ret = yield self.runInteraction(
                "_add_push_rule_relative_txn",
                self._add_push_rule_relative_txn,
                before=before,
                after=after,
                **vals
            )
        else:
            ret = yield self.runInteraction(
                "_add_push_rule_highest_priority_txn",
                self._add_push_rule_highest_priority_txn,
                **vals
            )
        defer.returnValue(ret)

    def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
        """Insert a rule directly before or after an existing rule.

        The new rule takes the priority class of the rule it is placed
        relative to. If the neighbouring priority slot is already taken, the
        rules on that side are shifted by one to make room.

        Args:
            txn: database cursor for the current transaction.
            user_name: owner of the rule.
            **kwargs: column values for the new row, plus optional ``before``
                and ``after`` rule ids. When both are given, ``before``
                decides both the anchor rule and the direction.

        Raises:
            RuleNotFoundException: the anchor rule does not exist for
                ``user_name``.
            InconsistentRuleException: ``priority_class`` was given and does
                not match the anchor rule's class.
        """
        new_rule = copy.copy(kwargs)
        before = new_rule.pop('before', None)
        after = new_rule.pop('after', None)

        # Choose the anchor rule and the direction from the SAME argument so
        # that passing both cannot place the rule on the wrong side.
        if before:
            relative_to_rule = before
            insert_after = False
        elif after:
            relative_to_rule = after
            insert_after = True
        else:
            # No anchor: the lookup below finds nothing and raises.
            relative_to_rule = None
            insert_after = False

        # get the priority of the rule we're inserting after/before
        sql = (
            "SELECT priority_class, priority FROM "+PushRuleTable.table_name+
            " WHERE user_name = ? and rule_id = ?"
        )
        txn.execute(sql, (user_name, relative_to_rule))
        res = txn.fetchall()
        if not res:
            raise RuleNotFoundException(
                "before/after rule not found: %s" % (relative_to_rule,)
            )
        (priority_class, base_rule_priority) = res[0]

        if ('priority_class' in new_rule
                and new_rule['priority_class'] != priority_class):
            raise InconsistentRuleException(
                "Given priority class does not match class of relative rule"
            )

        # Higher priority sorts first, so "after" means one lower.
        if insert_after:
            new_rule_priority = base_rule_priority - 1
        else:
            new_rule_priority = base_rule_priority + 1

        new_rule['priority_class'] = priority_class
        new_rule['user_name'] = user_name
        new_rule['priority'] = new_rule_priority

        # check if the priority before/after is free
        sql = (
            "SELECT COUNT(*) FROM "+PushRuleTable.table_name+
            " WHERE user_name = ? AND priority_class = ? AND priority = ?"
        )
        txn.execute(sql, (user_name, priority_class, new_rule_priority))
        num_conflicting = txn.fetchall()[0][0]

        # if there are conflicting rules, bump everything on that side
        if num_conflicting:
            if insert_after:
                shift, comparison = "-1", "<= ?"
            else:
                shift, comparison = "+1", ">= ?"
            sql = (
                "UPDATE "+PushRuleTable.table_name+
                " SET priority = priority "+shift+
                " WHERE user_name = ? AND priority_class = ? AND priority "
                + comparison
            )
            txn.execute(sql, (user_name, priority_class, new_rule_priority))

        # now insert the new rule
        self._insert_or_replace_push_rule_txn(txn, new_rule)

    def _add_push_rule_highest_priority_txn(self, txn, user_name,
                                            priority_class, **kwargs):
        """Insert a rule with the highest priority in its priority class.

        Args:
            txn: database cursor for the current transaction.
            user_name: owner of the rule.
            priority_class: class the rule is inserted into.
            **kwargs: remaining column values; any ``id`` is discarded.
        """
        # find the highest priority rule in that class
        sql = (
            "SELECT COUNT(*), MAX(priority) FROM "+PushRuleTable.table_name+
            " WHERE user_name = ? and priority_class = ?"
        )
        txn.execute(sql, (user_name, priority_class))
        (how_many, highest_prio) = txn.fetchall()[0]

        # MAX() yields NULL when the class is empty, so start from 0 then.
        new_prio = highest_prio + 1 if how_many > 0 else 0

        # and insert the new rule
        new_rule = copy.copy(kwargs)
        new_rule.pop('id', None)
        new_rule['user_name'] = user_name
        new_rule['priority_class'] = priority_class
        new_rule['priority'] = new_prio

        self._insert_or_replace_push_rule_txn(txn, new_rule)

    def _insert_or_replace_push_rule_txn(self, txn, new_rule):
        """Run ``INSERT OR REPLACE`` for one rule given as a column dict."""
        # Freeze the key order once so column names and bound values line
        # up, and pass a real list: a dict view is not a valid parameter
        # sequence for DB-API drivers on Python 3.
        columns = list(new_rule.keys())
        sql = (
            "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
            + ",".join(columns)+") VALUES ("
            + ", ".join(["?" for _ in columns])+")"
        )
        txn.execute(sql, [new_rule[column] for column in columns])

    @defer.inlineCallbacks
    def delete_push_rule(self, user_name, rule_id):
        """Delete the rule identified by ``user_name`` and ``rule_id``.

        Delegates to ``_simple_delete_one`` with those two columns as the
        match criteria.
        """
        yield self._simple_delete_one(
            PushRuleTable.table_name,
            {
                'user_name': user_name,
                'rule_id': rule_id
            }
        )

class RuleNotFoundException(Exception):
    """Raised when the rule named by before/after cannot be found."""


class InconsistentRuleException(Exception):
    """Raised when a rule's priority class conflicts with its neighbour's."""


class PushRuleTable(Table):
    """Name and column list of the table that stores push rules."""

    table_name = "push_rules"

    # Column names, in the order used when building SELECT statements and
    # when naming the fields of EntryType.
    fields = [
        "id",
        "user_name", "rule_id",
        "priority_class", "priority",
        "conditions", "actions",
    ]

    # Named tuple with one attribute per column listed above.
    EntryType = collections.namedtuple("PushRuleEntry", fields)