summary refs log tree commit diff
path: root/synapse/storage/appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/appservice.py')
-rw-r--r--synapse/storage/appservice.py126
1 files changed, 17 insertions, 109 deletions
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 371600eebb..d1ee533fac 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,16 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-import yaml
 import simplejson as json
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
-from synapse.appservice import ApplicationService, AppServiceTransaction
-from synapse.config._base import ConfigError
+from synapse.appservice import AppServiceTransaction
+from synapse.config.appservice import load_appservices
 from synapse.storage.roommember import RoomsForUser
-from synapse.types import UserID
 from ._base import SQLBaseStore
 
 
@@ -34,7 +31,7 @@ class ApplicationServiceStore(SQLBaseStore):
     def __init__(self, hs):
         super(ApplicationServiceStore, self).__init__(hs)
         self.hostname = hs.hostname
-        self.services_cache = ApplicationServiceStore.load_appservices(
+        self.services_cache = load_appservices(
             hs.hostname,
             hs.config.app_service_config_files
         )
@@ -144,102 +141,6 @@ class ApplicationServiceStore(SQLBaseStore):
 
         return rooms_for_user_matching_user_id
 
-    @classmethod
-    def _load_appservice(cls, hostname, as_info, config_filename):
-        required_string_fields = [
-            "id", "url", "as_token", "hs_token", "sender_localpart"
-        ]
-        for field in required_string_fields:
-            if not isinstance(as_info.get(field), basestring):
-                raise KeyError("Required string field: '%s' (%s)" % (
-                    field, config_filename,
-                ))
-
-        localpart = as_info["sender_localpart"]
-        if urllib.quote(localpart) != localpart:
-            raise ValueError(
-                "sender_localpart needs characters which are not URL encoded."
-            )
-        user = UserID(localpart, hostname)
-        user_id = user.to_string()
-
-        # namespace checks
-        if not isinstance(as_info.get("namespaces"), dict):
-            raise KeyError("Requires 'namespaces' object.")
-        for ns in ApplicationService.NS_LIST:
-            # specific namespaces are optional
-            if ns in as_info["namespaces"]:
-                # expect a list of dicts with exclusive and regex keys
-                for regex_obj in as_info["namespaces"][ns]:
-                    if not isinstance(regex_obj, dict):
-                        raise ValueError(
-                            "Expected namespace entry in %s to be an object,"
-                            " but got %s", ns, regex_obj
-                        )
-                    if not isinstance(regex_obj.get("regex"), basestring):
-                        raise ValueError(
-                            "Missing/bad type 'regex' key in %s", regex_obj
-                        )
-                    if not isinstance(regex_obj.get("exclusive"), bool):
-                        raise ValueError(
-                            "Missing/bad type 'exclusive' key in %s", regex_obj
-                        )
-        return ApplicationService(
-            token=as_info["as_token"],
-            url=as_info["url"],
-            namespaces=as_info["namespaces"],
-            hs_token=as_info["hs_token"],
-            sender=user_id,
-            id=as_info["id"],
-        )
-
-    @classmethod
-    def load_appservices(cls, hostname, config_files):
-        """Returns a list of Application Services from the config files."""
-        if not isinstance(config_files, list):
-            logger.warning(
-                "Expected %s to be a list of AS config files.", config_files
-            )
-            return []
-
-        # Dicts of value -> filename
-        seen_as_tokens = {}
-        seen_ids = {}
-
-        appservices = []
-
-        for config_file in config_files:
-            try:
-                with open(config_file, 'r') as f:
-                    appservice = ApplicationServiceStore._load_appservice(
-                        hostname, yaml.load(f), config_file
-                    )
-                    if appservice.id in seen_ids:
-                        raise ConfigError(
-                            "Cannot reuse ID across application services: "
-                            "%s (files: %s, %s)" % (
-                                appservice.id, config_file, seen_ids[appservice.id],
-                            )
-                        )
-                    seen_ids[appservice.id] = config_file
-                    if appservice.token in seen_as_tokens:
-                        raise ConfigError(
-                            "Cannot reuse as_token across application services: "
-                            "%s (files: %s, %s)" % (
-                                appservice.token,
-                                config_file,
-                                seen_as_tokens[appservice.token],
-                            )
-                        )
-                    seen_as_tokens[appservice.token] = config_file
-                    logger.info("Loaded application service: %s", appservice)
-                    appservices.append(appservice)
-            except Exception as e:
-                logger.error("Failed to load appservice from '%s'", config_file)
-                logger.exception(e)
-                raise
-        return appservices
-
 
 class ApplicationServiceTransactionStore(SQLBaseStore):
 
@@ -397,6 +298,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             dict(txn_id=txn_id, as_id=service.id)
         )
 
+    @defer.inlineCallbacks
     def get_oldest_unsent_txn(self, service):
         """Get the oldest transaction which has not been sent for this
         service.
@@ -407,12 +309,23 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             A Deferred which resolves to an AppServiceTransaction or
             None.
         """
-        return self.runInteraction(
+        entry = yield self.runInteraction(
             "get_oldest_unsent_appservice_txn",
             self._get_oldest_unsent_txn,
             service
         )
 
+        if not entry:
+            defer.returnValue(None)
+
+        event_ids = json.loads(entry["event_ids"])
+
+        events = yield self._get_events(event_ids)
+
+        defer.returnValue(AppServiceTransaction(
+            service=service, id=entry["txn_id"], events=events
+        ))
+
     def _get_oldest_unsent_txn(self, txn, service):
         # Monotonically increasing txn ids, so just select the smallest
         # one in the txns table (we delete them when they are sent)
@@ -427,12 +340,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
 
         entry = rows[0]
 
-        event_ids = json.loads(entry["event_ids"])
-        events = self._get_events_txn(txn, event_ids)
-
-        return AppServiceTransaction(
-            service=service, id=entry["txn_id"], events=events
-        )
+        return entry
 
     def _get_last_txn(self, txn, service_id):
         txn.execute(