summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/appservice.py86
-rw-r--r--synapse/storage/event_push_actions.py5
-rw-r--r--synapse/storage/push_rule.py54
-rw-r--r--synapse/storage/pusher.py22
-rw-r--r--synapse/storage/receipts.py14
-rw-r--r--synapse/storage/registration.py26
-rw-r--r--synapse/storage/roommember.py13
7 files changed, 116 insertions, 104 deletions
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index eab58d9ce9..b5aa55c0a3 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -15,12 +15,12 @@
 import logging
 import urllib
 import yaml
-from simplejson import JSONDecodeError
 import simplejson as json
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
 from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.config._base import ConfigError
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import UserID
 from ._base import SQLBaseStore
@@ -144,66 +144,9 @@ class ApplicationServiceStore(SQLBaseStore):
 
         return rooms_for_user_matching_user_id
 
-    def _parse_services_dict(self, results):
-        # SQL results in the form:
-        # [
-        #   {
-        #     'regex': "something",
-        #     'url': "something",
-        #     'namespace': enum,
-        #     'as_id': 0,
-        #     'token': "something",
-        #     'hs_token': "otherthing",
-        #     'id': 0
-        #   }
-        # ]
-        services = {}
-        for res in results:
-            as_token = res["token"]
-            if as_token is None:
-                continue
-            if as_token not in services:
-                # add the service
-                services[as_token] = {
-                    "id": res["id"],
-                    "url": res["url"],
-                    "token": as_token,
-                    "hs_token": res["hs_token"],
-                    "sender": res["sender"],
-                    "namespaces": {
-                        ApplicationService.NS_USERS: [],
-                        ApplicationService.NS_ALIASES: [],
-                        ApplicationService.NS_ROOMS: []
-                    }
-                }
-            # add the namespace regex if one exists
-            ns_int = res["namespace"]
-            if ns_int is None:
-                continue
-            try:
-                services[as_token]["namespaces"][
-                    ApplicationService.NS_LIST[ns_int]].append(
-                    json.loads(res["regex"])
-                )
-            except IndexError:
-                logger.error("Bad namespace enum '%s'. %s", ns_int, res)
-            except JSONDecodeError:
-                logger.error("Bad regex object '%s'", res["regex"])
-
-        service_list = []
-        for service in services.values():
-            service_list.append(ApplicationService(
-                token=service["token"],
-                url=service["url"],
-                namespaces=service["namespaces"],
-                hs_token=service["hs_token"],
-                sender=service["sender"],
-                id=service["id"]
-            ))
-        return service_list
-
     def _load_appservice(self, as_info):
         required_string_fields = [
+            # TODO: Add id here when it's stable to release
             "url", "as_token", "hs_token", "sender_localpart"
         ]
         for field in required_string_fields:
@@ -245,7 +188,7 @@ class ApplicationServiceStore(SQLBaseStore):
             namespaces=as_info["namespaces"],
             hs_token=as_info["hs_token"],
             sender=user_id,
-            id=as_info["as_token"]  # the token is the only unique thing here
+            id=as_info["id"] if "id" in as_info else as_info["as_token"],
         )
 
     def _populate_appservice_cache(self, config_files):
@@ -256,15 +199,38 @@ class ApplicationServiceStore(SQLBaseStore):
             )
             return
 
+        # Dicts of value -> filename
+        seen_as_tokens = {}
+        seen_ids = {}
+
         for config_file in config_files:
             try:
                 with open(config_file, 'r') as f:
                     appservice = self._load_appservice(yaml.load(f))
+                    if appservice.id in seen_ids:
+                        raise ConfigError(
+                            "Cannot reuse ID across application services: "
+                            "%s (files: %s, %s)" % (
+                                appservice.id, config_file, seen_ids[appservice.id],
+                            )
+                        )
+                    seen_ids[appservice.id] = config_file
+                    if appservice.token in seen_as_tokens:
+                        raise ConfigError(
+                            "Cannot reuse as_token across application services: "
+                            "%s (files: %s, %s)" % (
+                                appservice.token,
+                                config_file,
+                                seen_as_tokens[appservice.token],
+                            )
+                        )
+                    seen_as_tokens[appservice.token] = config_file
                     logger.info("Loaded application service: %s", appservice)
                     self.services_cache.append(appservice)
             except Exception as e:
                 logger.error("Failed to load appservice from '%s'", config_file)
                 logger.exception(e)
+                raise
 
 
 class ApplicationServiceTransactionStore(SQLBaseStore):
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index d99171ee87..6b7cebc9ce 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -17,7 +17,7 @@ from ._base import SQLBaseStore
 from twisted.internet import defer
 
 import logging
-import simplejson as json
+import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -84,7 +84,8 @@ class EventPushActionsStore(SQLBaseStore):
             )
             )
             return [
-                {"event_id": row[0], "actions": row[1]} for row in txn.fetchall()
+                {"event_id": row[0], "actions": json.loads(row[1])}
+                for row in txn.fetchall()
             ]
 
         ret = yield self.runInteraction(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 448009b4b6..2adfefd994 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -25,11 +25,11 @@ logger = logging.getLogger(__name__)
 
 class PushRuleStore(SQLBaseStore):
     @cachedInlineCallbacks()
-    def get_push_rules_for_user(self, user_name):
+    def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
             table="push_rules",
             keyvalues={
-                "user_name": user_name,
+                "user_name": user_id,
             },
             retcols=(
                 "user_name", "rule_id", "priority_class", "priority",
@@ -45,11 +45,11 @@ class PushRuleStore(SQLBaseStore):
         defer.returnValue(rows)
 
     @cachedInlineCallbacks()
-    def get_push_rules_enabled_for_user(self, user_name):
+    def get_push_rules_enabled_for_user(self, user_id):
         results = yield self._simple_select_list(
             table="push_rules_enable",
             keyvalues={
-                'user_name': user_name
+                'user_name': user_id
             },
             retcols=(
                 "user_name", "rule_id", "enabled",
@@ -122,7 +122,7 @@ class PushRuleStore(SQLBaseStore):
             )
             defer.returnValue(ret)
 
-    def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
+    def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
         after = kwargs.pop("after", None)
         relative_to_rule = kwargs.pop("before", after)
 
@@ -130,7 +130,7 @@ class PushRuleStore(SQLBaseStore):
             txn,
             table="push_rules",
             keyvalues={
-                "user_name": user_name,
+                "user_name": user_id,
                 "rule_id": relative_to_rule,
             },
             retcols=["priority_class", "priority"],
@@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
         new_rule.pop("before", None)
         new_rule.pop("after", None)
         new_rule['priority_class'] = priority_class
-        new_rule['user_name'] = user_name
+        new_rule['user_name'] = user_id
         new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
 
         # check if the priority before/after is free
@@ -170,7 +170,7 @@ class PushRuleStore(SQLBaseStore):
             "SELECT COUNT(*) FROM push_rules"
             " WHERE user_name = ? AND priority_class = ? AND priority = ?"
         )
-        txn.execute(sql, (user_name, priority_class, new_rule_priority))
+        txn.execute(sql, (user_id, priority_class, new_rule_priority))
         res = txn.fetchall()
         num_conflicting = res[0][0]
 
@@ -187,14 +187,14 @@ class PushRuleStore(SQLBaseStore):
             else:
                 sql += ">= ?"
 
-            txn.execute(sql, (user_name, priority_class, new_rule_priority))
+            txn.execute(sql, (user_id, priority_class, new_rule_priority))
 
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, (user_name,)
+            self.get_push_rules_for_user.invalidate, (user_id,)
         )
 
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
+            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
         )
 
         self._simple_insert_txn(
@@ -203,14 +203,14 @@ class PushRuleStore(SQLBaseStore):
             values=new_rule,
         )
 
-    def _add_push_rule_highest_priority_txn(self, txn, user_name,
+    def _add_push_rule_highest_priority_txn(self, txn, user_id,
                                             priority_class, **kwargs):
         # find the highest priority rule in that class
         sql = (
             "SELECT COUNT(*), MAX(priority) FROM push_rules"
             " WHERE user_name = ? and priority_class = ?"
         )
-        txn.execute(sql, (user_name, priority_class))
+        txn.execute(sql, (user_id, priority_class))
         res = txn.fetchall()
         (how_many, highest_prio) = res[0]
 
@@ -221,15 +221,15 @@ class PushRuleStore(SQLBaseStore):
         # and insert the new rule
         new_rule = kwargs
         new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
-        new_rule['user_name'] = user_name
+        new_rule['user_name'] = user_id
         new_rule['priority_class'] = priority_class
         new_rule['priority'] = new_prio
 
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, (user_name,)
+            self.get_push_rules_for_user.invalidate, (user_id,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
+            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
         )
 
         self._simple_insert_txn(
@@ -239,48 +239,48 @@ class PushRuleStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def delete_push_rule(self, user_name, rule_id):
+    def delete_push_rule(self, user_id, rule_id):
         """
         Delete a push rule. Args specify the row to be deleted and can be
         any of the columns in the push_rule table, but below are the
         standard ones
 
         Args:
-            user_name (str): The matrix ID of the push rule owner
+            user_id (str): The matrix ID of the push rule owner
             rule_id (str): The rule_id of the rule to be deleted
         """
         yield self._simple_delete_one(
             "push_rules",
-            {'user_name': user_name, 'rule_id': rule_id},
+            {'user_name': user_id, 'rule_id': rule_id},
             desc="delete_push_rule",
         )
 
-        self.get_push_rules_for_user.invalidate((user_name,))
-        self.get_push_rules_enabled_for_user.invalidate((user_name,))
+        self.get_push_rules_for_user.invalidate((user_id,))
+        self.get_push_rules_enabled_for_user.invalidate((user_id,))
 
     @defer.inlineCallbacks
-    def set_push_rule_enabled(self, user_name, rule_id, enabled):
+    def set_push_rule_enabled(self, user_id, rule_id, enabled):
         ret = yield self.runInteraction(
             "_set_push_rule_enabled_txn",
             self._set_push_rule_enabled_txn,
-            user_name, rule_id, enabled
+            user_id, rule_id, enabled
         )
         defer.returnValue(ret)
 
-    def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
+    def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
         new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
         self._simple_upsert_txn(
             txn,
             "push_rules_enable",
-            {'user_name': user_name, 'rule_id': rule_id},
+            {'user_name': user_id, 'rule_id': rule_id},
             {'enabled': 1 if enabled else 0},
             {'id': new_id},
         )
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, (user_name,)
+            self.get_push_rules_for_user.invalidate, (user_id,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
+            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
         )
 
 
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 2b90d6c622..8ec706178a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -80,7 +80,7 @@ class PusherStore(SQLBaseStore):
         defer.returnValue(rows)
 
     @defer.inlineCallbacks
-    def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
+    def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data):
         try:
@@ -90,7 +90,7 @@ class PusherStore(SQLBaseStore):
                 dict(
                     app_id=app_id,
                     pushkey=pushkey,
-                    user_name=user_name,
+                    user_name=user_id,
                 ),
                 dict(
                     access_token=access_token,
@@ -112,38 +112,38 @@ class PusherStore(SQLBaseStore):
             raise StoreError(500, "Problem creating pusher.")
 
     @defer.inlineCallbacks
-    def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
+    def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
         yield self._simple_delete_one(
             "pushers",
-            {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
-            desc="delete_pusher_by_app_id_pushkey_user_name",
+            {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
+            desc="delete_pusher_by_app_id_pushkey_user_id",
         )
 
     @defer.inlineCallbacks
-    def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
+    def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
         yield self._simple_update_one(
             "pushers",
-            {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+            {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
             {'last_token': last_token},
             desc="update_pusher_last_token",
         )
 
     @defer.inlineCallbacks
-    def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
+    def update_pusher_last_token_and_success(self, app_id, pushkey, user_id,
                                              last_token, last_success):
         yield self._simple_update_one(
             "pushers",
-            {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+            {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
             {'last_token': last_token, 'last_success': last_success},
             desc="update_pusher_last_token_and_success",
         )
 
     @defer.inlineCallbacks
-    def update_pusher_failing_since(self, app_id, pushkey, user_name,
+    def update_pusher_failing_since(self, app_id, pushkey, user_id,
                                     failing_since):
         yield self._simple_update_one(
             "pushers",
-            {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+            {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
             {'failing_since': failing_since},
             desc="update_pusher_failing_since",
         )
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 21cf88b3da..c80e576620 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
 from synapse.util.caches import cache_counter, caches_by_name
 
 from twisted.internet import defer
@@ -33,6 +33,18 @@ class ReceiptsStore(SQLBaseStore):
 
         self._receipts_stream_cache = _RoomStreamChangeCache()
 
+    @cached(num_args=2)
+    def get_receipts_for_room(self, room_id, receipt_type):
+        return self._simple_select_list(
+            table="receipts_linearized",
+            keyvalues={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+            },
+            retcols=("user_id", "event_id"),
+            desc="get_receipts_for_room",
+        )
+
     @defer.inlineCallbacks
     def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         """Get receipts for multiple rooms for sending to clients.
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 999b710fbb..70cde0d04d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import StoreError, Codes
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 
 class RegistrationStore(SQLBaseStore):
@@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
         defer.returnValue(res if res else False)
 
     @cachedInlineCallbacks()
-    def is_guest(self, user):
+    def is_guest(self, user_id):
         res = yield self._simple_select_one_onecol(
             table="users",
-            keyvalues={"name": user.to_string()},
+            keyvalues={"name": user_id},
             retcol="is_guest",
             allow_none=True,
             desc="is_guest",
@@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
 
         defer.returnValue(res if res else False)
 
+    @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+                inlineCallbacks=True)
+    def are_guests(self, user_ids):
+        sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
+            ",".join("?" for _ in user_ids),
+        )
+
+        rows = yield self._execute(
+            "are_guests", self.cursor_to_dict, sql, *user_ids
+        )
+
+        result = {user_id: False for user_id in user_ids}
+
+        result.update({
+            row["name"]: bool(row["is_guest"])
+            for row in rows
+        })
+
+        defer.returnValue(result)
+
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.is_guest, access_tokens.id as token_id"
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7d3ce4579d..68ac88905f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
             txn.execute(sql, (user_id, room_id))
         yield self.runInteraction("forget_membership", f)
         self.was_forgotten_at.invalidate_all()
+        self.who_forgot_in_room.invalidate_all()
         self.did_forget.invalidate((user_id, room_id))
 
     @cachedInlineCallbacks(num_args=2)
@@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
             return rows[0][0]
         forgot = yield self.runInteraction("did_forget_membership_at", f)
         defer.returnValue(forgot == 1)
+
+    @cached()
+    def who_forgot_in_room(self, room_id):
+        return self._simple_select_list(
+            table="room_memberships",
+            retcols=("user_id", "event_id"),
+            keyvalues={
+                "room_id": room_id,
+                "forgotten": 1,
+            },
+            desc="who_forgot"
+        )