summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py59
-rw-r--r--synapse/util/async_helpers.py32
-rw-r--r--synapse/util/caches/__init__.py32
-rw-r--r--synapse/util/caches/deferred_cache.py11
-rw-r--r--synapse/util/caches/descriptors.py67
-rw-r--r--synapse/util/caches/expiringcache.py12
-rw-r--r--synapse/util/caches/lrucache.py42
-rw-r--r--synapse/util/distributor.py11
-rw-r--r--synapse/util/gai_resolver.py75
-rw-r--r--synapse/util/linked_list.py4
-rw-r--r--synapse/util/metrics.py12
-rw-r--r--synapse/util/stringutils.py21
-rw-r--r--synapse/util/versionstring.py82
13 files changed, 235 insertions, 225 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 95f23e27b6..f157132210 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,9 +14,8 @@
 
 import json
 import logging
-import re
 import typing
-from typing import Any, Callable, Dict, Generator, Optional, Pattern
+from typing import Any, Callable, Dict, Generator, Optional
 
 import attr
 from frozendict import frozendict
@@ -35,9 +34,6 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-_WILDCARD_RUN = re.compile(r"([\?\*]+)")
-
-
 def _reject_invalid_json(val: Any) -> None:
     """Do not allow Infinity, -Infinity, or NaN values in JSON."""
     raise ValueError("Invalid JSON value: '%s'" % val)
@@ -185,56 +181,3 @@ def log_failure(
     if not consumeErrors:
         return failure
     return None
-
-
-def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern:
-    """Converts a glob to a compiled regex object.
-
-    Args:
-        glob: pattern to match
-        word_boundary: If True, the pattern will be allowed to match at word boundaries
-           anywhere in the string. Otherwise, the pattern is anchored at the start and
-           end of the string.
-
-    Returns:
-        compiled regex pattern
-    """
-
-    # Patterns with wildcards must be simplified to avoid performance cliffs
-    # - The glob `?**?**?` is equivalent to the glob `???*`
-    # - The glob `???*` is equivalent to the regex `.{3,}`
-    chunks = []
-    for chunk in _WILDCARD_RUN.split(glob):
-        # No wildcards? re.escape()
-        if not _WILDCARD_RUN.match(chunk):
-            chunks.append(re.escape(chunk))
-            continue
-
-        # Wildcards? Simplify.
-        qmarks = chunk.count("?")
-        if "*" in chunk:
-            chunks.append(".{%d,}" % qmarks)
-        else:
-            chunks.append(".{%d}" % qmarks)
-
-    res = "".join(chunks)
-
-    if word_boundary:
-        res = re_word_boundary(res)
-    else:
-        # \A anchors at start of string, \Z at end of string
-        res = r"\A" + res + r"\Z"
-
-    return re.compile(res, re.IGNORECASE)
-
-
-def re_word_boundary(r: str) -> str:
-    """
-    Adds word boundary characters to the start and end of an
-    expression to require that the match occur as a whole word,
-    but do so respecting the fact that strings starting or ending
-    with non-word characters will change word boundaries.
-    """
-    # we can't use \b as it chokes on unicode. however \W seems to be okay
-    # as shorthand for [^0-9A-Za-z_].
-    return r"(^|\W)%s(\W|$)" % (r,)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 561b962e14..20ce294209 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -27,6 +27,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    Iterator,
     Optional,
     Set,
     TypeVar,
@@ -40,7 +41,6 @@ from typing_extensions import ContextManager
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.logging.context import (
@@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", [])
 
-        def callback(r):
+        def callback(r: _T) -> _T:
             object.__setattr__(self, "_result", (True, r))
 
             # once we have set _result, no more entries will be added to _observers,
@@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
                     )
             return r
 
-        def errback(f):
+        def errback(f: Failure) -> Optional[Failure]:
             object.__setattr__(self, "_result", (False, f))
 
             # once we have set _result, no more entries will be added to _observers,
@@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
             for observer in observers:
                 # This is a little bit of magic to correctly propagate stack
                 # traces when we `await` on one of the observer deferreds.
-                f.value.__failure__ = f
+                f.value.__failure__ = f  # type: ignore[union-attr]
                 try:
                     observer.errback(f)
                 except Exception as e:
@@ -314,7 +314,7 @@ class Linearizer:
         # will release the lock.
 
         @contextmanager
-        def _ctx_manager(_):
+        def _ctx_manager(_: None) -> Iterator[None]:
             try:
                 yield
             finally:
@@ -355,7 +355,7 @@ class Linearizer:
         new_defer = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
-        def cb(_r):
+        def cb(_r: None) -> "defer.Deferred[None]":
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
             entry.count += 1
 
@@ -371,7 +371,7 @@ class Linearizer:
             # code must be synchronous, so this is the only sensible place.)
             return self._clock.sleep(0)
 
-        def eb(e):
+        def eb(e: Failure) -> Failure:
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
@@ -435,7 +435,7 @@ class ReadWriteLock:
             await make_deferred_yieldable(curr_writer)
 
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
                 yield
             finally:
@@ -464,7 +464,7 @@ class ReadWriteLock:
         await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
                 yield
             finally:
@@ -524,7 +524,7 @@ def timeout_deferred(
 
     delayed_call = reactor.callLater(timeout, time_it_out)
 
-    def convert_cancelled(value: failure.Failure):
+    def convert_cancelled(value: Failure) -> Failure:
         # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
@@ -534,7 +534,7 @@ def timeout_deferred(
 
     deferred.addErrback(convert_cancelled)
 
-    def cancel_timeout(result):
+    def cancel_timeout(result: _T) -> _T:
         # stop the pending call to cancel the deferred if it's been fired
         if delayed_call.active():
             delayed_call.cancel()
@@ -542,11 +542,11 @@ def timeout_deferred(
 
     deferred.addBoth(cancel_timeout)
 
-    def success_cb(val):
+    def success_cb(val: _T) -> None:
         if not new_d.called:
             new_d.callback(val)
 
-    def failure_cb(val):
+    def failure_cb(val: Failure) -> None:
         if not new_d.called:
             new_d.errback(val)
 
@@ -557,13 +557,13 @@ def timeout_deferred(
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DoneAwaitable:  # should be: Generic[R]
     """Simple awaitable that returns the provided value."""
 
-    value = attr.ib(type=Any)  # should be: R
+    value: Any  # should be: R
 
-    def __await__(self):
+    def __await__(self) -> Any:
         return self
 
     def __iter__(self) -> "DoneAwaitable":
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index df4d61e4b6..15debd6c46 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -17,7 +17,7 @@ import logging
 import typing
 from enum import Enum, auto
 from sys import intern
-from typing import Callable, Dict, Optional, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized
 
 import attr
 from prometheus_client.core import Gauge
@@ -58,20 +58,20 @@ class EvictionReason(Enum):
     time = auto()
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class CacheMetric:
 
-    _cache = attr.ib()
-    _cache_type = attr.ib(type=str)
-    _cache_name = attr.ib(type=str)
-    _collect_callback = attr.ib(type=Optional[Callable])
+    _cache: Sized
+    _cache_type: str
+    _cache_name: str
+    _collect_callback: Optional[Callable]
 
-    hits = attr.ib(default=0)
-    misses = attr.ib(default=0)
+    hits: int = 0
+    misses: int = 0
     eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
         factory=collections.Counter
     )
-    memory_usage = attr.ib(default=None)
+    memory_usage: Optional[int] = None
 
     def inc_hits(self) -> None:
         self.hits += 1
@@ -89,13 +89,14 @@ class CacheMetric:
         self.memory_usage += memory
 
     def dec_memory_usage(self, memory: int) -> None:
+        assert self.memory_usage is not None
         self.memory_usage -= memory
 
     def clear_memory_usage(self) -> None:
         if self.memory_usage is not None:
             self.memory_usage = 0
 
-    def describe(self):
+    def describe(self) -> List[str]:
         return []
 
     def collect(self) -> None:
@@ -118,8 +119,9 @@ class CacheMetric:
                         self.eviction_size_by_reason[reason]
                     )
                 cache_total.labels(self._cache_name).set(self.hits + self.misses)
-                if getattr(self._cache, "max_size", None):
-                    cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+                max_size = getattr(self._cache, "max_size", None)
+                if max_size:
+                    cache_max_size.labels(self._cache_name).set(max_size)
 
                 if TRACK_MEMORY_USAGE:
                     # self.memory_usage can be None if nothing has been inserted
@@ -193,7 +195,7 @@ KNOWN_KEYS = {
 }
 
 
-def intern_string(string):
+def intern_string(string: Optional[str]) -> Optional[str]:
     """Takes a (potentially) unicode string and interns it if it's ascii"""
     if string is None:
         return None
@@ -204,7 +206,7 @@ def intern_string(string):
         return string
 
 
-def intern_dict(dictionary):
+def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
     """Takes a dictionary and interns well known keys and their values"""
     return {
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@@ -212,7 +214,7 @@ def intern_dict(dictionary):
     }
 
 
-def _intern_known_values(key, value):
+def _intern_known_values(key: str, value: Any) -> Any:
     intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
 
     if key in intern_keys:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index da502aec11..377c9a282a 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     MutableMapping,
     Optional,
+    Sized,
     TypeVar,
     Union,
     cast,
@@ -104,7 +105,13 @@ class DeferredCache(Generic[KT, VT]):
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d) or 1) if iterable else None,
+            size_callback=(
+                (lambda d: len(cast(Sized, d)) or 1)
+                # Argument 1 to "len" has incompatible type "VT"; expected "Sized"
+                # We trust that `VT` is `Sized` when `iterable` is `True`
+                if iterable
+                else None
+            ),
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
             prune_unread_entries=prune_unread_entries,
@@ -289,7 +296,7 @@ class DeferredCache(Generic[KT, VT]):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
-    def invalidate(self, key) -> None:
+    def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
 
         If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index b9dcca17f1..375cd443f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,12 +19,15 @@ import logging
 from typing import (
     Any,
     Callable,
+    Dict,
     Generic,
+    Hashable,
     Iterable,
     Mapping,
     Optional,
     Sequence,
     Tuple,
+    Type,
     TypeVar,
     Union,
     cast,
@@ -32,6 +35,7 @@ from typing import (
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
+from twisted.python.failure import Failure
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        num_args: Optional[int],
+        cache_context: bool = False,
+    ):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __init__(
         self,
-        orig,
+        orig: Callable[..., Any],
         max_entries: int = 1000,
         cache_context: bool = False,
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = LruCacheDescriptor._Sentinel.sentinel
 
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             invalidate_callback = kwargs.pop("on_invalidate", None)
             callbacks = (invalidate_callback,) if invalidate_callback else ()
 
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             return r1 + r2
 
     Args:
-        num_args (int): number of positional arguments (excluding ``self`` and
+        num_args: number of positional arguments (excluding ``self`` and
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
 
     def __init__(
         self,
-        orig,
-        max_entries=1000,
-        num_args=None,
-        tree=False,
-        cache_context=False,
-        iterable=False,
+        orig: Callable[..., Any],
+        max_entries: int = 1000,
+        num_args: Optional[int] = None,
+        tree: bool = False,
+        cache_context: bool = False,
+        iterable: bool = False,
         prune_unread_entries: bool = True,
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     of results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        cached_method_name: str,
+        list_name: str,
+        num_args: Optional[int] = None,
+    ):
         """
         Args:
-            orig (function)
-            cached_method_name (str): The name of the cached method.
-            list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int): number of positional arguments (excluding ``self``,
+            orig
+            cached_method_name: The name of the cached method.
+            list_name: Name of the argument which is the bulk lookup list
+            num_args: number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 % (self.list_name, cached_method_name)
             )
 
-    def __get__(self, obj, objtype=None):
+    def __get__(
+        self, obj: Optional[Any], objtype: Optional[Type] = None
+    ) -> Callable[..., Any]:
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
             results = {}
 
-            def update_results_dict(res, arg):
+            def update_results_dict(res: Any, arg: Hashable) -> None:
                 results[arg] = res
 
             # list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             # otherwise a tuple is used.
             if num_args == 1:
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
             else:
                 keylist = list(keyargs)
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     key = arg_to_cache_key(arg)
                     cache.set(key, deferred, callback=invalidate_callback)
 
-                def complete_all(res):
+                def complete_all(res: Dict[Hashable, Any]) -> None:
                     # the wrapped function has completed. It returns a
                     # a dict. We can now resolve the observable deferreds in
                     # the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                         deferreds_map[e].callback(val)
                         results[e] = val
 
-                def errback(f):
+                def errback(f: Failure) -> Failure:
                     # the wrapped function has failed. Invalidate any cache
                     # entries we're supposed to be populating, and fail
                     # their deferreds.
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c3f72aa06d..67ee4c693b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
 import attr
 from typing_extensions import Literal
 
+from twisted.internet import defer
+
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
@@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
             # Don't bother starting the loop if things never expire
             return
 
-        def f():
+        def f() -> "defer.Deferred[None]":
             return run_as_background_process(
                 "prune_cache_%s" % self._cache_name, self._prune_cache
             )
@@ -157,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
             self[key] = value
             return value
 
-    def _prune_cache(self) -> None:
+    async def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
         return False
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _CacheEntry:
-    time = attr.ib(type=int)
-    value = attr.ib()
+    time: int
+    value: Any
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a0a7a9de32..eb96f7e665 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,14 +15,15 @@
 import logging
 import threading
 import weakref
+from enum import Enum
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Collection,
+    Dict,
     Generic,
-    Iterable,
     List,
     Optional,
     Type,
@@ -190,7 +191,7 @@ class _Node(Generic[KT, VT]):
         root: "ListNode[_Node]",
         key: KT,
         value: VT,
-        cache: "weakref.ReferenceType[LruCache]",
+        cache: "weakref.ReferenceType[LruCache[KT, VT]]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
         prune_unread_entries: bool = True,
@@ -270,7 +271,10 @@ class _Node(Generic[KT, VT]):
         removed from all lists.
         """
         cache = self._cache()
-        if not cache or not cache.pop(self.key, None):
+        if (
+            cache is None
+            or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
+        ):
             # `cache.pop` should call `drop_from_lists()`, unless this Node had
             # already been removed from the cache.
             self.drop_from_lists()
@@ -290,6 +294,12 @@ class _Node(Generic[KT, VT]):
             self._global_list_node.update_last_access(clock)
 
 
+class _Sentinel(Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -302,7 +312,7 @@ class LruCache(Generic[KT, VT]):
         max_size: int,
         cache_name: Optional[str] = None,
         cache_type: Type[Union[dict, TreeCache]] = dict,
-        size_callback: Optional[Callable] = None,
+        size_callback: Optional[Callable[[VT], int]] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
@@ -339,7 +349,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache = cache_type()
+        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
@@ -374,7 +384,7 @@ class LruCache(Generic[KT, VT]):
         # creating more each time we create a `_Node`.
         weak_ref_to_self = weakref.ref(self)
 
-        list_root = ListNode[_Node].create_root_node()
+        list_root = ListNode[_Node[KT, VT]].create_root_node()
 
         lock = threading.Lock()
 
@@ -422,7 +432,7 @@ class LruCache(Generic[KT, VT]):
         def add_node(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
-            node = _Node(
+            node: _Node[KT, VT] = _Node(
                 list_root,
                 key,
                 value,
@@ -439,10 +449,10 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node) -> None:
+        def move_node_to_front(node: _Node[KT, VT]) -> None:
             node.move_to_front(real_clock, list_root)
 
-        def delete_node(node: _Node) -> int:
+        def delete_node(node: _Node[KT, VT]) -> int:
             node.drop_from_lists()
 
             deleted_len = 1
@@ -496,7 +506,7 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_set(
-            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
             node = cache.get(key, None)
             if node is not None:
@@ -590,8 +600,6 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache
 
-        self.sentinel = object()
-
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
 
@@ -608,18 +616,18 @@ class LruCache(Generic[KT, VT]):
         self.clear = cache_clear
 
     def __getitem__(self, key: KT) -> VT:
-        result = self.get(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.get(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return cast(VT, result)
+            return result
 
     def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
     def __delitem__(self, key: KT, value: VT) -> None:
-        result = self.pop(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.pop(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
 
     def __len__(self) -> int:
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 31097d6439..91837655f8 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -18,12 +18,13 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
 from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
 
-def user_left_room(distributor, user, room_id):
+def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
     distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
@@ -63,7 +64,7 @@ class Distributor:
                 self.pre_registration[name] = []
             self.pre_registration[name].append(observer)
 
-    def fire(self, name: str, *args, **kwargs) -> None:
+    def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
         """Dispatches the given signal to the registered observers.
 
         Runs the observers as a background process. Does not return a deferred.
@@ -95,7 +96,7 @@ class Signal:
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
+    def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.
@@ -103,7 +104,7 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         completed."""
 
-        async def do(observer):
+        async def do(observer: Callable[..., Any]) -> Any:
             try:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
@@ -120,5 +121,5 @@ class Signal:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index a447ce4e55..214eb17fbc 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -3,23 +3,52 @@
 # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
 # private class.
 
-
 from socket import (
     AF_INET,
     AF_INET6,
     AF_UNSPEC,
     SOCK_DGRAM,
     SOCK_STREAM,
+    AddressFamily,
+    SocketKind,
     gaierror,
     getaddrinfo,
 )
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from zope.interface import implementer
 
 from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.interfaces import (
+    IAddress,
+    IHostnameResolver,
+    IHostResolution,
+    IReactorThreads,
+    IResolutionReceiver,
+)
 from twisted.internet.threads import deferToThreadPool
 
+if TYPE_CHECKING:
+    # The types below are copied from
+    # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
+    # so that the type hints can match the interfaces.
+    from twisted.python.runtime import platform
+
+    if platform.supportsThreads():
+        from twisted.python.threadpool import ThreadPool
+    else:
+        ThreadPool = object  # type: ignore[misc, assignment]
+
 
 @implementer(IHostResolution)
 class HostResolution:
@@ -27,13 +56,13 @@ class HostResolution:
     The in-progress resolution of a given hostname.
     """
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         """
         Create a L{HostResolution} with the given name.
         """
         self.name = name
 
-    def cancel(self):
+    def cancel(self) -> NoReturn:
         # IHostResolution.cancel
         raise NotImplementedError()
 
@@ -62,6 +91,17 @@ _socktypeToType = {
 }
 
 
+_GETADDRINFO_RESULT = List[
+    Tuple[
+        AddressFamily,
+        SocketKind,
+        int,
+        str,
+        Union[Tuple[str, int], Tuple[str, int, int, int]],
+    ]
+]
+
+
 @implementer(IHostnameResolver)
 class GAIResolver:
     """
@@ -69,7 +109,12 @@ class GAIResolver:
     L{getaddrinfo} in a thread.
     """
 
-    def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+    def __init__(
+        self,
+        reactor: IReactorThreads,
+        getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
+        getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
+    ):
         """
         Create a L{GAIResolver}.
         @param reactor: the reactor to schedule result-delivery on
@@ -89,14 +134,16 @@ class GAIResolver:
         )
         self._getaddrinfo = getaddrinfo
 
-    def resolveHostName(
+    # The types on IHostnameResolver is incorrect in Twisted, see
+    # https://twistedmatrix.com/trac/ticket/10276
+    def resolveHostName(  # type: ignore[override]
         self,
-        resolutionReceiver,
-        hostName,
-        portNumber=0,
-        addressTypes=None,
-        transportSemantics="TCP",
-    ):
+        resolutionReceiver: IResolutionReceiver,
+        hostName: str,
+        portNumber: int = 0,
+        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+        transportSemantics: str = "TCP",
+    ) -> IHostResolution:
         """
         See L{IHostnameResolver.resolveHostName}
         @param resolutionReceiver: see interface
@@ -112,7 +159,7 @@ class GAIResolver:
         ]
         socketType = _transportToSocket[transportSemantics]
 
-        def get():
+        def get() -> _GETADDRINFO_RESULT:
             try:
                 return self._getaddrinfo(
                     hostName, portNumber, addressFamily, socketType
@@ -125,7 +172,7 @@ class GAIResolver:
         resolutionReceiver.resolutionBegan(resolution)
 
         @d.addCallback
-        def deliverResults(result):
+        def deliverResults(result: _GETADDRINFO_RESULT) -> None:
             for family, socktype, _proto, _cannoname, sockaddr in result:
                 addrType = _afToType[family]
                 resolutionReceiver.addressResolved(
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
index 9f4be757ba..8efbf061aa 100644
--- a/synapse/util/linked_list.py
+++ b/synapse/util/linked_list.py
@@ -84,7 +84,7 @@ class ListNode(Generic[P]):
         # immediately rather than at the next GC.
         self.cache_entry = None
 
-    def move_after(self, node: "ListNode") -> None:
+    def move_after(self, node: "ListNode[P]") -> None:
         """Move this node from its current location in the list to after the
         given node.
         """
@@ -122,7 +122,7 @@ class ListNode(Generic[P]):
         self.prev_node = None
         self.next_node = None
 
-    def _refs_insert_after(self, node: "ListNode") -> None:
+    def _refs_insert_after(self, node: "ListNode[P]") -> None:
         """Internal method to insert the node after the given node."""
 
         # This method should only be called when we're not already in the list.
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1e784b3f1f..98ee49af6e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -56,14 +56,22 @@ block_db_sched_duration = Counter(
     "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
 )
 
+
+# This is dynamically created in InFlightGauge.__init__.
+class _InFlightMetric(Protocol):
+    real_time_max: float
+    real_time_sum: float
+
+
 # Tracks the number of blocks currently active
-in_flight = InFlightGauge(
+in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
     "synapse_util_metrics_block_in_flight",
     "",
     labels=["block_name"],
     sub_metrics=["real_time_max", "real_time_sum"],
 )
 
+
 T = TypeVar("T", bound=Callable[..., Any])
 
 
@@ -180,7 +188,7 @@ class Measure:
         """
         return self._logging_context.get_resource_usage()
 
-    def _update_in_flight(self, metrics) -> None:
+    def _update_in_flight(self, metrics: _InFlightMetric) -> None:
         """Gets called when processing in flight metrics"""
         assert self.start is not None
         duration = self.clock.time() - self.start
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f029432191..ea1032b4fc 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -19,6 +19,8 @@ import string
 from collections.abc import Iterable
 from typing import Optional, Tuple
 
+from netaddr import valid_ipv6
+
 from synapse.api.errors import Codes, SynapseError
 
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -97,7 +99,10 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
         raise ValueError("Invalid server name '%s'" % server_name)
 
 
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+# An approximation of the domain name syntax in RFC 1035, section 2.3.1.
+# NB: "\Z" is not equivalent to "$".
+#     The latter will match the position before a "\n" at the end of a string.
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")
 
 
 def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
@@ -122,13 +127,15 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
     if host[0] == "[":
         if host[-1] != "]":
             raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
-        return host, port
 
-    # otherwise it should only be alphanumerics.
-    if not VALID_HOST_REGEX.match(host):
-        raise ValueError(
-            "Server name '%s' contains invalid characters" % (server_name,)
-        )
+        # valid_ipv6 raises when given an empty string
+        ipv6_address = host[1:-1]
+        if not ipv6_address or not valid_ipv6(ipv6_address):
+            raise ValueError(
+                "Server name '%s' is not a valid IPv6 address" % (server_name,)
+            )
+    elif not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' has an invalid format" % (server_name,))
 
     return host, port
 
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 899ee0adc8..c144ff62c1 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,10 +30,11 @@ def get_version_string(module: ModuleType) -> str:
     If called on a module not in a git checkout will return `__version__`.
 
     Args:
-        module (module)
+        module: The module to check the version of. Must declare a __version__
+            attribute.
 
     Returns:
-        str
+        The module version (as a string).
     """
 
     cached_version = version_cache.get(module)
@@ -44,71 +46,37 @@ def get_version_string(module: ModuleType) -> str:
     version_string = module.__version__  # type: ignore[attr-defined]
 
     try:
-        null = open(os.devnull, "w")
         cwd = os.path.dirname(os.path.abspath(module.__file__))
 
-        try:
-            git_branch = (
-                subprocess.check_output(
-                    ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+        def _run_git_command(prefix: str, *params: str) -> str:
+            try:
+                result = (
+                    subprocess.check_output(
+                        ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd
+                    )
+                    .strip()
+                    .decode("ascii")
                 )
-                .strip()
-                .decode("ascii")
-            )
-            git_branch = "b=" + git_branch
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            # FileNotFoundError can arise when git is not installed
-            git_branch = ""
-
-        try:
-            git_tag = (
-                subprocess.check_output(
-                    ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-            )
-            git_tag = "t=" + git_tag
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_tag = ""
-
-        try:
-            git_commit = (
-                subprocess.check_output(
-                    ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-            )
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_commit = ""
-
-        try:
-            dirty_string = "-this_is_a_dirty_checkout"
-            is_dirty = (
-                subprocess.check_output(
-                    ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-                .endswith(dirty_string)
-            )
+                return prefix + result
+            except (subprocess.CalledProcessError, FileNotFoundError):
+                return ""
 
-            git_dirty = "dirty" if is_dirty else ""
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_dirty = ""
+        git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD")
+        git_tag = _run_git_command("t=", "describe", "--exact-match")
+        git_commit = _run_git_command("", "rev-parse", "--short", "HEAD")
+
+        dirty_string = "-this_is_a_dirty_checkout"
+        is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith(
+            dirty_string
+        )
+        git_dirty = "dirty" if is_dirty else ""
 
         if git_branch or git_tag or git_commit or git_dirty:
             git_version = ",".join(
                 s for s in (git_branch, git_tag, git_commit, git_dirty) if s
             )
 
-            version_string = "%s (%s)" % (
-                # If the __version__ attribute doesn't exist, we'll have failed
-                # loudly above.
-                module.__version__,  # type: ignore[attr-defined]
-                git_version,
-            )
+            version_string = f"{version_string} ({git_version})"
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)