summary refs log tree commit diff
path: root/synapse/appservice
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice')
-rw-r--r--synapse/appservice/__init__.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b989007314..d5a7a5ce2f 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 from synapse.api.constants import EventTypes
 from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import GroupID, get_domain_from_id
 
 from twisted.internet import defer
 
@@ -81,12 +82,13 @@ class ApplicationService(object):
     # values.
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
-    def __init__(self, token, url=None, namespaces=None, hs_token=None,
+    def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
                  sender=None, id=None, protocols=None, rate_limited=True):
         self.token = token
         self.url = url
         self.hs_token = hs_token
         self.sender = sender
+        self.server_name = hostname
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
 
@@ -125,6 +127,24 @@ class ApplicationService(object):
                     raise ValueError(
                         "Expected bool for 'exclusive' in ns '%s'" % ns
                     )
+                group_id = regex_obj.get("group_id")
+                if group_id:
+                    if not isinstance(group_id, str):
+                        raise ValueError(
+                            "Expected string for 'group_id' in ns '%s'" % ns
+                        )
+                    try:
+                        GroupID.from_string(group_id)
+                    except Exception:
+                        raise ValueError(
+                            "Expected valid group ID for 'group_id' in ns '%s'" % ns
+                        )
+
+                    if get_domain_from_id(group_id) != self.server_name:
+                        raise ValueError(
+                            "Expected 'group_id' to be this host in ns '%s'" % ns
+                        )
+
                 regex = regex_obj.get("regex")
                 if isinstance(regex, basestring):
                     regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
@@ -251,6 +271,21 @@ class ApplicationService(object):
             if regex_obj["exclusive"]
         ]
 
+    def get_groups_for_user(self, user_id):
+        """Get the groups that this user is associated with by this AS
+
+        Args:
+            user_id (str): The ID of the user.
+
+        Returns:
+            iterable[str]: an iterable that yields group_id strings.
+        """
+        return (
+            regex_obj["group_id"]
+            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+        )
+
     def is_rate_limited(self):
         return self.rate_limited