diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b989007314..57ed8a3ca2 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -12,13 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+import logging
+import re
+
+from six import string_types
from twisted.internet import defer
-import logging
-import re
+from synapse.api.constants import EventTypes
+from synapse.types import GroupID, get_domain_from_id
+from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -81,14 +84,17 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
- def __init__(self, token, url=None, namespaces=None, hs_token=None,
- sender=None, id=None, protocols=None, rate_limited=True):
+ def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
+ sender=None, id=None, protocols=None, rate_limited=True,
+ ip_range_whitelist=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
+ self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
+ self.ip_range_whitelist = ip_range_whitelist
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@@ -125,8 +131,26 @@ class ApplicationService(object):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
+ group_id = regex_obj.get("group_id")
+ if group_id:
+ if not isinstance(group_id, str):
+ raise ValueError(
+ "Expected string for 'group_id' in ns '%s'" % ns
+ )
+ try:
+ GroupID.from_string(group_id)
+ except Exception:
+ raise ValueError(
+ "Expected valid group ID for 'group_id' in ns '%s'" % ns
+ )
+
+ if get_domain_from_id(group_id) != self.server_name:
+ raise ValueError(
+ "Expected 'group_id' to be this host in ns '%s'" % ns
+ )
+
regex = regex_obj.get("regex")
- if isinstance(regex, basestring):
+ if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError(
@@ -251,8 +275,27 @@ class ApplicationService(object):
if regex_obj["exclusive"]
]
+ def get_groups_for_user(self, user_id):
+ """Get the groups that this user is associated with by this AS
+
+ Args:
+ user_id (str): The ID of the user.
+
+ Returns:
+ iterable[str]: an iterable that yields group_id strings.
+ """
+ return (
+ regex_obj["group_id"]
+ for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+ if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+ )
+
def is_rate_limited(self):
return self.rate_limited
def __str__(self):
- return "ApplicationService: %s" % (self.__dict__,)
+ # copy dictionary and redact token fields so they don't get logged
+ dict_copy = self.__dict__.copy()
+ dict_copy["token"] = "<redacted>"
+ dict_copy["hs_token"] = "<redacted>"
+ return "ApplicationService: %s" % (dict_copy,)
|