summary refs log tree commit diff
path: root/synapse/appservice/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice/__init__.py')
-rw-r--r--synapse/appservice/__init__.py59
1 files changed, 51 insertions, 8 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b989007314..57ed8a3ca2 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -12,13 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.api.constants import EventTypes
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+import logging
+import re
+
+from six import string_types
 
 from twisted.internet import defer
 
-import logging
-import re
+from synapse.api.constants import EventTypes
+from synapse.types import GroupID, get_domain_from_id
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
 
@@ -81,14 +84,17 @@ class ApplicationService(object):
     # values.
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
-    def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, id=None, protocols=None, rate_limited=True):
+    def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
+                 sender=None, id=None, protocols=None, rate_limited=True,
+                 ip_range_whitelist=None):
         self.token = token
         self.url = url
         self.hs_token = hs_token
         self.sender = sender
+        self.server_name = hostname
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
+        self.ip_range_whitelist = ip_range_whitelist
 
         if "|" in self.id:
             raise Exception("application service ID cannot contain '|' character")
@@ -125,8 +131,26 @@ class ApplicationService(object):
                     raise ValueError(
                         "Expected bool for 'exclusive' in ns '%s'" % ns
                     )
+                group_id = regex_obj.get("group_id")
+                if group_id:
+                    if not isinstance(group_id, str):
+                        raise ValueError(
+                            "Expected string for 'group_id' in ns '%s'" % ns
+                        )
+                    try:
+                        GroupID.from_string(group_id)
+                    except Exception:
+                        raise ValueError(
+                            "Expected valid group ID for 'group_id' in ns '%s'" % ns
+                        )
+
+                    if get_domain_from_id(group_id) != self.server_name:
+                        raise ValueError(
+                            "Expected 'group_id' to be this host in ns '%s'" % ns
+                        )
+
                 regex = regex_obj.get("regex")
-                if isinstance(regex, basestring):
+                if isinstance(regex, string_types):
                     regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
                 else:
                     raise ValueError(
@@ -251,8 +275,27 @@ class ApplicationService(object):
             if regex_obj["exclusive"]
         ]
 
+    def get_groups_for_user(self, user_id):
+        """Get the groups that this user is associated with by this AS
+
+        Args:
+            user_id (str): The ID of the user.
+
+        Returns:
+            iterable[str]: an iterable that yields group_id strings.
+        """
+        return (
+            regex_obj["group_id"]
+            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+        )
+
     def is_rate_limited(self):
         return self.rate_limited
 
     def __str__(self):
-        return "ApplicationService: %s" % (self.__dict__,)
+        # copy dictionary and redact token fields so they don't get logged
+        dict_copy = self.__dict__.copy()
+        dict_copy["token"] = "<redacted>"
+        dict_copy["hs_token"] = "<redacted>"
+        return "ApplicationService: %s" % (dict_copy,)