summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/config/registration.py2
-rw-r--r--synapse/config/tracer.py37
-rw-r--r--synapse/handlers/account_validity.py55
-rw-r--r--synapse/handlers/send_email.py98
-rw-r--r--synapse/push/mailer.py53
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/registration.py3
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py13
-rw-r--r--synapse/util/caches/descriptors.py14
-rw-r--r--synapse/util/stringutils.py23
14 files changed, 188 insertions, 128 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index efc926d094..458306eba5 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -87,6 +87,7 @@ class Auth:
         )
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
+        self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
 
     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
@@ -208,6 +209,8 @@ class Auth:
                 opentracing.set_tag("authenticated_entity", user_id)
                 opentracing.set_tag("user_id", user_id)
                 opentracing.set_tag("appservice_id", app_service.id)
+                if user_id in self._force_tracing_for_users:
+                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
                 return requester
 
@@ -260,6 +263,8 @@ class Auth:
             opentracing.set_tag("user_id", user_info.user_id)
             if device_id:
                 opentracing.set_tag("device_id", device_id)
+            if user_info.token_owner in self._force_tracing_for_users:
+                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
             return requester
         except KeyError:
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index e6f52b4f40..d9dc55a0c3 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -349,4 +349,4 @@ class RegistrationConfig(Config):
 
     def read_arguments(self, args):
         if args.enable_registration is not None:
-            self.enable_registration = bool(strtobool(str(args.enable_registration)))
+            self.enable_registration = strtobool(str(args.enable_registration))
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index db22b5b19f..d0ea17261f 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Set
+
 from synapse.python_dependencies import DependencyException, check_requirements
 
 from ._base import Config, ConfigError
@@ -32,6 +34,8 @@ class TracerConfig(Config):
             {"sampler": {"type": "const", "param": 1}, "logging": False},
         )
 
+        self.force_tracing_for_users: Set[str] = set()
+
         if not self.opentracer_enabled:
             return
 
@@ -48,6 +52,19 @@ class TracerConfig(Config):
         if not isinstance(self.opentracer_whitelist, list):
             raise ConfigError("Tracer homeserver_whitelist config is malformed")
 
+        force_tracing_for_users = opentracing_config.get("force_tracing_for_users", [])
+        if not isinstance(force_tracing_for_users, list):
+            raise ConfigError(
+                "Expected a list", ("opentracing", "force_tracing_for_users")
+            )
+        for i, u in enumerate(force_tracing_for_users):
+            if not isinstance(u, str):
+                raise ConfigError(
+                    "Expected a string",
+                    ("opentracing", "force_tracing_for_users", f"index {i}"),
+                )
+            self.force_tracing_for_users.add(u)
+
     def generate_config_section(cls, **kwargs):
         return """\
         ## Opentracing ##
@@ -64,7 +81,8 @@ class TracerConfig(Config):
             #enabled: true
 
             # The list of homeservers we wish to send and receive span contexts and span baggage.
-            # See docs/opentracing.rst
+            # See docs/opentracing.rst.
+            #
             # This is a list of regexes which are matched against the server_name of the
             # homeserver.
             #
@@ -73,19 +91,26 @@ class TracerConfig(Config):
             #homeserver_whitelist:
             #  - ".*"
 
+            # A list of the matrix IDs of users whose requests will always be traced,
+            # even if the tracing system would otherwise drop the traces due to
+            # probabilistic sampling.
+            #
+            # By default, the list is empty.
+            #
+            #force_tracing_for_users:
+            #  - "@user1:server_name"
+            #  - "@user2:server_name"
+
             # Jaeger can be configured to sample traces at different rates.
             # All configuration options provided by Jaeger can be set here.
-            # Jaeger's configuration mostly related to trace sampling which
+            # Jaeger's configuration is mostly related to trace sampling which
             # is documented here:
-            # https://www.jaegertracing.io/docs/1.13/sampling/.
+            # https://www.jaegertracing.io/docs/latest/sampling/.
             #
             #jaeger_config:
             #  sampler:
             #    type: const
             #    param: 1
-
-            #  Logging whether spans were started and reported
-            #
             #  logging:
             #    false
         """
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 5b927f10b3..d752cf34f0 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,12 +15,9 @@
 import email.mime.multipart
 import email.utils
 import logging
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.api.errors import StoreError, SynapseError
-from synapse.logging.context import make_deferred_yieldable
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
@@ -36,9 +33,11 @@ class AccountValidityHandler:
         self.hs = hs
         self.config = hs.config
         self.store = self.hs.get_datastore()
-        self.sendmail = self.hs.get_sendmail()
+        self.send_email_handler = self.hs.get_send_email_handler()
         self.clock = self.hs.get_clock()
 
+        self._app_name = self.hs.config.email_app_name
+
         self._account_validity_enabled = (
             hs.config.account_validity.account_validity_enabled
         )
@@ -63,23 +62,10 @@ class AccountValidityHandler:
             self._template_text = (
                 hs.config.account_validity.account_validity_template_text
             )
-            account_validity_renew_email_subject = (
+            self._renew_email_subject = (
                 hs.config.account_validity.account_validity_renew_email_subject
             )
 
-            try:
-                app_name = hs.config.email_app_name
-
-                self._subject = account_validity_renew_email_subject % {"app": app_name}
-
-                self._from_string = hs.config.email_notif_from % {"app": app_name}
-            except Exception:
-                # If substitution failed, fall back to the bare strings.
-                self._subject = account_validity_renew_email_subject
-                self._from_string = hs.config.email_notif_from
-
-            self._raw_from = email.utils.parseaddr(self._from_string)[1]
-
             # Check the renewal emails to send and send them every 30min.
             if hs.config.run_background_tasks:
                 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
@@ -159,38 +145,17 @@ class AccountValidityHandler:
         }
 
         html_text = self._template_html.render(**template_vars)
-        html_part = MIMEText(html_text, "html", "utf8")
-
         plain_text = self._template_text.render(**template_vars)
-        text_part = MIMEText(plain_text, "plain", "utf8")
 
         for address in addresses:
             raw_to = email.utils.parseaddr(address)[1]
 
-            multipart_msg = MIMEMultipart("alternative")
-            multipart_msg["Subject"] = self._subject
-            multipart_msg["From"] = self._from_string
-            multipart_msg["To"] = address
-            multipart_msg["Date"] = email.utils.formatdate()
-            multipart_msg["Message-ID"] = email.utils.make_msgid()
-            multipart_msg.attach(text_part)
-            multipart_msg.attach(html_part)
-
-            logger.info("Sending renewal email to %s", address)
-
-            await make_deferred_yieldable(
-                self.sendmail(
-                    self.hs.config.email_smtp_host,
-                    self._raw_from,
-                    raw_to,
-                    multipart_msg.as_string().encode("utf8"),
-                    reactor=self.hs.get_reactor(),
-                    port=self.hs.config.email_smtp_port,
-                    requireAuthentication=self.hs.config.email_smtp_user is not None,
-                    username=self.hs.config.email_smtp_user,
-                    password=self.hs.config.email_smtp_pass,
-                    requireTransportSecurity=self.hs.config.require_transport_security,
-                )
+            await self.send_email_handler.send_email(
+                email_address=raw_to,
+                subject=self._renew_email_subject,
+                app_name=self._app_name,
+                html=html_text,
+                text=plain_text,
             )
 
         await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
new file mode 100644
index 0000000000..e9f6aef06f
--- /dev/null
+++ b/synapse/handlers/send_email.py
@@ -0,0 +1,98 @@
+# Copyright 2021 The Matrix.org C.I.C. Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+import logging
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import TYPE_CHECKING
+
+from synapse.logging.context import make_deferred_yieldable
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SendEmailHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+
+        self._sendmail = hs.get_sendmail()
+        self._reactor = hs.get_reactor()
+
+        self._from = hs.config.email.email_notif_from
+        self._smtp_host = hs.config.email.email_smtp_host
+        self._smtp_port = hs.config.email.email_smtp_port
+        self._smtp_user = hs.config.email.email_smtp_user
+        self._smtp_pass = hs.config.email.email_smtp_pass
+        self._require_transport_security = hs.config.email.require_transport_security
+
+    async def send_email(
+        self,
+        email_address: str,
+        subject: str,
+        app_name: str,
+        html: str,
+        text: str,
+    ) -> None:
+        """Send a multipart email with the given information.
+
+        Args:
+            email_address: The address to send the email to.
+            subject: The email's subject.
+            app_name: The app name to include in the From header.
+            html: The HTML content to include in the email.
+            text: The plain text content to include in the email.
+        """
+        try:
+            from_string = self._from % {"app": app_name}
+        except (KeyError, TypeError):
+            from_string = self._from
+
+        raw_from = email.utils.parseaddr(from_string)[1]
+        raw_to = email.utils.parseaddr(email_address)[1]
+
+        if raw_to == "":
+            raise RuntimeError("Invalid 'to' address")
+
+        html_part = MIMEText(html, "html", "utf8")
+        text_part = MIMEText(text, "plain", "utf8")
+
+        multipart_msg = MIMEMultipart("alternative")
+        multipart_msg["Subject"] = subject
+        multipart_msg["From"] = from_string
+        multipart_msg["To"] = email_address
+        multipart_msg["Date"] = email.utils.formatdate()
+        multipart_msg["Message-ID"] = email.utils.make_msgid()
+        multipart_msg.attach(text_part)
+        multipart_msg.attach(html_part)
+
+        logger.info("Sending email to %s" % email_address)
+
+        await make_deferred_yieldable(
+            self._sendmail(
+                self._smtp_host,
+                raw_from,
+                raw_to,
+                multipart_msg.as_string().encode("utf8"),
+                reactor=self._reactor,
+                port=self._smtp_port,
+                requireAuthentication=self._smtp_user is not None,
+                username=self._smtp_user,
+                password=self._smtp_pass,
+                requireTransportSecurity=self._require_transport_security,
+            )
+        )
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index c4b43b0d3f..5f9ea5003a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -12,12 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import email.mime.multipart
-import email.utils
 import logging
 import urllib.parse
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
 
 import bleach
@@ -27,7 +23,6 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import StoreError
 from synapse.config.emailconfig import EmailSubjectConfig
 from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable
 from synapse.push.presentable_names import (
     calculate_room_name,
     descriptor_from_member_events,
@@ -108,7 +103,7 @@ class Mailer:
         self.template_html = template_html
         self.template_text = template_text
 
-        self.sendmail = self.hs.get_sendmail()
+        self.send_email_handler = hs.get_send_email_handler()
         self.store = self.hs.get_datastore()
         self.state_store = self.hs.get_storage().state
         self.macaroon_gen = self.hs.get_macaroon_generator()
@@ -310,17 +305,6 @@ class Mailer:
         self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
     ) -> None:
         """Send an email with the given information and template text"""
-        try:
-            from_string = self.hs.config.email_notif_from % {"app": self.app_name}
-        except TypeError:
-            from_string = self.hs.config.email_notif_from
-
-        raw_from = email.utils.parseaddr(from_string)[1]
-        raw_to = email.utils.parseaddr(email_address)[1]
-
-        if raw_to == "":
-            raise RuntimeError("Invalid 'to' address")
-
         template_vars = {
             "app_name": self.app_name,
             "server_name": self.hs.config.server.server_name,
@@ -329,35 +313,14 @@ class Mailer:
         template_vars.update(extra_template_vars)
 
         html_text = self.template_html.render(**template_vars)
-        html_part = MIMEText(html_text, "html", "utf8")
-
         plain_text = self.template_text.render(**template_vars)
-        text_part = MIMEText(plain_text, "plain", "utf8")
-
-        multipart_msg = MIMEMultipart("alternative")
-        multipart_msg["Subject"] = subject
-        multipart_msg["From"] = from_string
-        multipart_msg["To"] = email_address
-        multipart_msg["Date"] = email.utils.formatdate()
-        multipart_msg["Message-ID"] = email.utils.make_msgid()
-        multipart_msg.attach(text_part)
-        multipart_msg.attach(html_part)
-
-        logger.info("Sending email to %s" % email_address)
-
-        await make_deferred_yieldable(
-            self.sendmail(
-                self.hs.config.email_smtp_host,
-                raw_from,
-                raw_to,
-                multipart_msg.as_string().encode("utf8"),
-                reactor=self.hs.get_reactor(),
-                port=self.hs.config.email_smtp_port,
-                requireAuthentication=self.hs.config.email_smtp_user is not None,
-                username=self.hs.config.email_smtp_user,
-                password=self.hs.config.email_smtp_pass,
-                requireTransportSecurity=self.hs.config.require_transport_security,
-            )
+
+        await self.send_email_handler.send_email(
+            email_address=email_address,
+            subject=subject,
+            app_name=self.app_name,
+            html=html_text,
+            text=plain_text,
         )
 
     async def _get_room_vars(
diff --git a/synapse/server.py b/synapse/server.py
index 2337d2d9b4..fec0024c89 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -104,6 +104,7 @@ from synapse.handlers.room_list import RoomListHandler
 from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
 from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
 from synapse.handlers.search import SearchHandler
+from synapse.handlers.send_email import SendEmailHandler
 from synapse.handlers.set_password import SetPasswordHandler
 from synapse.handlers.space_summary import SpaceSummaryHandler
 from synapse.handlers.sso import SsoHandler
@@ -550,6 +551,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return SearchHandler(self)
 
     @cache_in_self
+    def get_send_email_handler(self) -> SendEmailHandler:
+        return SendEmailHandler(self)
+
+    @cache_in_self
     def get_set_password_handler(self) -> SetPasswordHandler:
         return SetPasswordHandler(self)
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3d98d3f5f8..0623da9aa1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import random
 from abc import ABCMeta
 from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
 
@@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._clock = hs.get_clock()
         self.database_engine = database.engine
         self.db_pool = database
-        self.rand = random.SystemRandom()
 
     def process_replication_rows(
         self,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..a1f98b7e38 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
     )
-    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
         rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         num_args=1,
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
-        self, user_ids: List[str]
+        self, user_ids: Iterable[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
         txn: Connection,
-        user_ids: List[str],
+        user_ids: Iterable[str],
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6e5ee557d2..e5c5cf8ff0 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 import re
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
@@ -997,7 +998,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         expiration_ts = now_ms + self._account_validity_period
 
         if use_delta:
-            expiration_ts = self.rand.randrange(
+            expiration_ts = random.randrange(
                 expiration_ts - self._account_validity_startup_job_max_delta,
                 expiration_ts,
             )
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index acf6b2fb64..1ecdd40c38 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Iterable
+
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
         return bool(result)
 
     @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
-    async def are_users_erased(self, user_ids):
+    async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
         """
         Checks which users in a list have requested erasure
 
         Args:
-            user_ids (iterable[str]): full user id to check
+            user_ids: full user ids to check
 
         Returns:
-            dict[str, bool]:
-                for each user, whether the user has requested erasure.
+            for each user, whether the user has requested erasure.
         """
-        # this serves the dual purpose of (a) making sure we can do len and
-        # iterate it multiple times, and (b) avoiding duplicates.
-        user_ids = tuple(set(user_ids))
-
         rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index ac4a078b26..3a4d027095 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -322,8 +322,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
-    Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped function.
+    Given an iterable of keys it looks in the cache to find any hits, then passes
+    the tuple of missing keys to the wrapped function.
 
     Once wrapped, the function returns a Deferred which resolves to the list
     of results.
@@ -437,7 +437,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     return f
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = list(missing)
+                # copy the missing set before sending it to the callee, to guard against
+                # modification.
+                args_to_call[self.list_name] = tuple(missing)
 
                 cached_defers.append(
                     defer.maybeDeferred(
@@ -522,14 +524,14 @@ def cachedList(
 
     Used to do batch lookups for an already created cache. A single argument
     is specified as a list that is iterated through to lookup keys in the
-    original cache. A new list consisting of the keys that weren't in the cache
-    get passed to the original function, the result of which is stored in the
+    original cache. A new tuple consisting of the (deduplicated) keys that weren't in
+    the cache gets passed to the original function, the result of which is stored in the
     cache.
 
     Args:
         cached_method_name: The name of the single-item lookup method.
             This is only used to find the cache to use.
-        list_name: The name of the argument that is the list to use to
+        list_name: The name of the argument that is the iterable to use to
             do batch lookups in the cache.
         num_args: Number of arguments to use as the key in the cache
             (including list_name). Defaults to all named parameters.
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 4f25cd1d26..f029432191 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import random
 import re
+import secrets
 import string
 from collections.abc import Iterable
 from typing import Optional, Tuple
@@ -35,26 +35,27 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 #
 MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 
-# random_string and random_string_with_symbols are used for a range of things,
-# some cryptographically important, some less so. We use SystemRandom to make sure
-# we get cryptographically-secure randoms.
-rand = random.SystemRandom()
-
 
 def random_string(length: int) -> str:
-    return "".join(rand.choice(string.ascii_letters) for _ in range(length))
+    """Generate a cryptographically secure string of random letters.
+
+    Drawn from the characters: `a-z` and `A-Z`
+    """
+    return "".join(secrets.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length: int) -> str:
-    return "".join(rand.choice(_string_with_symbols) for _ in range(length))
+    """Generate a cryptographically secure string of random letters/numbers/symbols.
+
+    Drawn from the characters: `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`
+    """
+    return "".join(secrets.choice(_string_with_symbols) for _ in range(length))
 
 
 def is_ascii(s: bytes) -> bool:
     try:
         s.decode("ascii").encode("ascii")
-    except UnicodeDecodeError:
-        return False
-    except UnicodeEncodeError:
+    except UnicodeError:
         return False
     return True