summary refs log tree commit diff
path: root/synapse/storage/appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/appservice.py')
-rw-r--r--synapse/storage/appservice.py126
1 files changed, 105 insertions, 21 deletions
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index d941b1f387..97481d113b 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -15,31 +15,21 @@
 import logging
 from twisted.internet import defer
 
+from synapse.api.constants import Membership
 from synapse.api.errors import StoreError
 from synapse.appservice import ApplicationService
+from synapse.storage.roommember import RoomsForUser
 from ._base import SQLBaseStore
 
 
 logger = logging.getLogger(__name__)
 
 
-class ApplicationServiceCache(object):
-    """Caches ApplicationServices and provides utility functions on top.
-
-    This class is designed to be invoked on incoming events in order to avoid
-    hammering the database every time to extract a list of application service
-    regexes.
-    """
-
-    def __init__(self):
-        self.services = []
-
-
 class ApplicationServiceStore(SQLBaseStore):
 
     def __init__(self, hs):
         super(ApplicationServiceStore, self).__init__(hs)
-        self.cache = ApplicationServiceCache()
+        self.services_cache = []
         self.cache_defer = self._populate_cache()
 
     @defer.inlineCallbacks
@@ -56,7 +46,7 @@ class ApplicationServiceStore(SQLBaseStore):
             token,
         )
         # update cache TODO: Should this be in the txn?
-        for service in self.cache.services:
+        for service in self.services_cache:
             if service.token == token:
                 service.url = None
                 service.namespaces = None
@@ -110,13 +100,13 @@ class ApplicationServiceStore(SQLBaseStore):
         )
 
         # update cache TODO: Should this be in the txn?
-        for (index, cache_service) in enumerate(self.cache.services):
+        for (index, cache_service) in enumerate(self.services_cache):
             if service.token == cache_service.token:
-                self.cache.services[index] = service
+                self.services_cache[index] = service
                 logger.info("Updated: %s", service)
                 return
         # new entry
-        self.cache.services.append(service)
+        self.services_cache.append(service)
         logger.info("Updated(new): %s", service)
 
     def _update_app_service_txn(self, txn, service):
@@ -160,11 +150,34 @@ class ApplicationServiceStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_app_services(self):
         yield self.cache_defer  # make sure the cache is ready
-        defer.returnValue(self.cache.services)
+        defer.returnValue(self.services_cache)
+
+    @defer.inlineCallbacks
+    def get_app_service_by_user_id(self, user_id):
+        """Retrieve an application service from their user ID.
+
+        All application services have associated with them a particular user ID.
+        There is no distinguishing feature on the user ID which indicates it
+        represents an application service. This function allows you to map from
+        a user ID to an application service.
+
+        Args:
+            user_id(str): The user ID to see if it is an application service.
+        Returns:
+            synapse.appservice.ApplicationService or None.
+        """
+
+        yield self.cache_defer  # make sure the cache is ready
+
+        for service in self.services_cache:
+            if service.sender == user_id:
+                defer.returnValue(service)
+                return
+        defer.returnValue(None)
 
     @defer.inlineCallbacks
     def get_app_service_by_token(self, token, from_cache=True):
-        """Get the application service with the given token.
+        """Get the application service with the given appservice token.
 
         Args:
             token (str): The application service token.
@@ -176,7 +189,7 @@ class ApplicationServiceStore(SQLBaseStore):
         yield self.cache_defer  # make sure the cache is ready
 
         if from_cache:
-            for service in self.cache.services:
+            for service in self.services_cache:
                 if service.token == token:
                     defer.returnValue(service)
                     return
@@ -185,6 +198,77 @@ class ApplicationServiceStore(SQLBaseStore):
         # TODO: The from_cache=False impl
         # TODO: This should be JOINed with the application_services_regex table.
 
+    def get_app_service_rooms(self, service):
+        """Get a list of RoomsForUser for this application service.
+
+        Application services may be "interested" in lots of rooms depending on
+        the room ID, the room aliases, or the members in the room. This function
+        takes all of these into account and returns a list of RoomsForUser which
+        represent the entire list of room IDs that this application service
+        wants to know about.
+
+        Args:
+            service: The application service to get a room list for.
+        Returns:
+            A list of RoomsForUser.
+        """
+        return self.runInteraction(
+            "get_app_service_rooms",
+            self._get_app_service_rooms_txn,
+            service,
+        )
+
+    def _get_app_service_rooms_txn(self, txn, service):
+        # get all rooms matching the room ID regex.
+        room_entries = self._simple_select_list_txn(
+            txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
+        )
+        matching_room_list = set([
+            r["room_id"] for r in room_entries if
+            service.is_interested_in_room(r["room_id"])
+        ])
+
+        # resolve room IDs for matching room alias regex.
+        room_alias_mappings = self._simple_select_list_txn(
+            txn=txn, table="room_aliases", keyvalues=None,
+            retcols=["room_id", "room_alias"]
+        )
+        matching_room_list |= set([
+            r["room_id"] for r in room_alias_mappings if
+            service.is_interested_in_alias(r["room_alias"])
+        ])
+
+        # get all rooms for every user for this AS. This is scoped to users on
+        # this HS only.
+        user_list = self._simple_select_list_txn(
+            txn=txn, table="users", keyvalues=None, retcols=["name"]
+        )
+        user_list = [
+            u["name"] for u in user_list if
+            service.is_interested_in_user(u["name"])
+        ]
+        rooms_for_user_matching_user_id = set()  # RoomsForUser list
+        for user_id in user_list:
+            # FIXME: This assumes this store is linked with RoomMemberStore :(
+            rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
+                txn=txn,
+                user_id=user_id,
+                membership_list=[Membership.JOIN]
+            )
+            rooms_for_user_matching_user_id |= set(rooms_for_user)
+
+        # make RoomsForUser tuples for room ids and aliases which are not in the
+        # main rooms_for_user_list - e.g. they are rooms which do not have AS
+        # registered users in it.
+        known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
+        missing_rooms_for_user = [
+            RoomsForUser(r, service.sender, "join") for r in
+            matching_room_list if r not in known_room_ids
+        ]
+        rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
+
+        return rooms_for_user_matching_user_id
+
     @defer.inlineCallbacks
     def _populate_cache(self):
         """Populates the ApplicationServiceCache from the database."""
@@ -235,7 +319,7 @@ class ApplicationServiceStore(SQLBaseStore):
         # TODO get last successful txn id f.e. service
         for service in services.values():
             logger.info("Found application service: %s", service)
-            self.cache.services.append(ApplicationService(
+            self.services_cache.append(ApplicationService(
                 token=service["token"],
                 url=service["url"],
                 namespaces=service["namespaces"],