diff --git a/synapse/__init__.py b/synapse/__init__.py
index c9bc8fb9e9..419299bf01 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.30.1"
+__version__ = "1.31.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 8d0f6b7b31..26cb1bc657 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -563,6 +563,9 @@ class Auth:
Returns:
bool: False if no access_token was given, True otherwise.
"""
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -579,6 +582,8 @@ class Auth:
MissingClientTokenError: If there isn't a single access_token in the
request
"""
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 43b1f1e94b..3912c8994c 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -21,8 +21,10 @@ import signal
import socket
import sys
import traceback
+import warnings
from typing import Awaitable, Callable, Iterable
+from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
from twisted.internet import defer, error, reactor
@@ -195,6 +197,25 @@ def listen_metrics(bind_addresses, port):
start_http_server(port, addr=host, registry=RegistryProxy)
+def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict):
+ # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
+ # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
+ # suppress the warning for now.
+ warnings.filterwarnings(
+ action="ignore",
+ category=CryptographyDeprecationWarning,
+ message="int_from_bytes is deprecated",
+ )
+
+ from synapse.util.manhole import manhole
+
+ listen_tcp(
+ bind_addresses,
+ port,
+ manhole(username="matrix", password="rabbithole", globals=manhole_globals),
+ )
+
+
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index edf52ddc32..b2d21acefd 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -147,7 +147,6 @@ from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.generic_worker")
@@ -640,12 +639,8 @@ class GenericWorkerServer(HomeServer):
if listener.type == "http":
self._listen_http(listener)
elif listener.type == "manhole":
- _base.listen_tcp(
- listener.bind_addresses,
- listener.port,
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
+ _base.listen_manhole(
+ listener.bind_addresses, listener.port, manhole_globals={"hs": self}
)
elif listener.type == "metrics":
if not self.get_config().enable_metrics:
@@ -792,13 +787,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- def on_start(self):
- # There may be some events that are persisted but haven't been sent,
- # so send them now.
- self.federation_sender.notify_new_events(
- self.store.get_room_max_stream_ordering()
- )
-
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 244657cb88..3bfe9d507f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -67,7 +67,6 @@ from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.versionstring import get_version_string
@@ -288,12 +287,8 @@ class SynapseHomeServer(HomeServer):
if listener.type == "http":
self._listening_services.extend(self._listener_http(config, listener))
elif listener.type == "manhole":
- listen_tcp(
- listener.bind_addresses,
- listener.port,
- manhole(
- username="matrix", password="rabbithole", globals={"hs": self}
- ),
+ _base.listen_manhole(
+ listener.bind_addresses, listener.port, manhole_globals={"hs": self}
)
elif listener.type == "replication":
services = listen_tcp(
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 8e03f14005..4e8abbf88a 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -24,7 +24,7 @@ from ._base import Config, ConfigError
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache.
-_CACHES = {}
+_CACHES = {} # type: Dict[str, Callable[[float], None]]
# a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock()
@@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str:
return cache_name.lower()
-def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
+def add_resizable_cache(
+ cache_name: str, cache_resize_callback: Callable[[float], None]
+):
"""Register a cache that's size can dynamically change
Args:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 747ab9a7fe..05733ec41d 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -79,6 +79,9 @@ class OIDCConfig(Config):
# Note that, if this is changed, users authenticating via that provider
# will no longer be recognised as the same user!
#
+ # (Use "oidc" here if you are migrating from an old "oidc_config"
+ # configuration.)
+ #
# idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms.
#
@@ -247,37 +250,6 @@ class OIDCConfig(Config):
# attribute_requirements:
# - attribute: userGroup
# value: "synapseUsers"
-
- # For use with Keycloak
- #
- #- idp_id: keycloak
- # idp_name: Keycloak
- # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
- # client_id: "synapse"
- # client_secret: "copy secret generated in Keycloak UI"
- # scopes: ["openid", "profile"]
- # attribute_requirements:
- # - attribute: groups
- # value: "admin"
-
- # For use with Github
- #
- #- idp_id: github
- # idp_name: Github
- # idp_brand: github
- # discover: false
- # issuer: "https://github.com/"
- # client_id: "your-client-id" # TO BE FILLED
- # client_secret: "your-client-secret" # TO BE FILLED
- # authorization_endpoint: "https://github.com/login/oauth/authorize"
- # token_endpoint: "https://github.com/login/oauth/access_token"
- # userinfo_endpoint: "https://api.github.com/user"
- # scopes: ["read:user"]
- # user_mapping_provider:
- # config:
- # subject_claim: "id"
- # localpart_template: "{{{{ user.login }}}}"
- # display_name_template: "{{{{ user.name }}}}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 3e993b428b..0c18c49abb 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,25 +31,39 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
-from typing import Dict, List, Tuple, Type
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Hashable,
+ Iterable,
+ List,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+)
from sortedcontainers import SortedDict
-from twisted.internet import defer
-
from synapse.api.presence import UserPresenceState
+from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.metrics import LaterGauge
+from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure
from .units import Edu
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-class FederationRemoteSendQueue:
+class FederationRemoteSendQueue(AbstractFederationSender):
"""A drop in replacement for FederationSender"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
@@ -58,7 +72,7 @@ class FederationRemoteSendQueue:
# We may have multiple federation sender instances, so we need to track
# their positions separately.
self._sender_instances = hs.config.worker.federation_shard_config.instances
- self._sender_positions = {}
+ self._sender_positions = {} # type: Dict[str, int]
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
@@ -71,7 +85,7 @@ class FederationRemoteSendQueue:
# Stream position -> (user_id, destinations)
self.presence_destinations = (
SortedDict()
- ) # type: SortedDict[int, Tuple[str, List[str]]]
+ ) # type: SortedDict[int, Tuple[str, Iterable[str]]]
# (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
@@ -94,7 +108,7 @@ class FederationRemoteSendQueue:
# we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
- def register(name, queue):
+ def register(name: str, queue: Sized) -> None:
LaterGauge(
"synapse_federation_send_queue_%s_size" % (queue_name,),
"",
@@ -115,13 +129,13 @@ class FederationRemoteSendQueue:
self.clock.looping_call(self._clear_queue, 30 * 1000)
- def _next_pos(self):
+ def _next_pos(self) -> int:
pos = self.pos
self.pos += 1
self.pos_time[self.clock.time_msec()] = pos
return pos
- def _clear_queue(self):
+ def _clear_queue(self) -> None:
"""Clear the queues for anything older than N minutes"""
FIVE_MINUTES_AGO = 5 * 60 * 1000
@@ -138,7 +152,7 @@ class FederationRemoteSendQueue:
self._clear_queue_before_pos(position_to_delete)
- def _clear_queue_before_pos(self, position_to_delete):
+ def _clear_queue_before_pos(self, position_to_delete: int) -> None:
"""Clear all the queues from before a given position"""
with Measure(self.clock, "send_queue._clear"):
# Delete things out of presence maps
@@ -188,13 +202,18 @@ class FederationRemoteSendQueue:
for key in keys[:i]:
del self.edus[key]
- def notify_new_events(self, max_token):
+ def notify_new_events(self, max_token: RoomStreamToken) -> None:
"""As per FederationSender"""
- # We don't need to replicate this as it gets sent down a different
- # stream.
- pass
+ # This should never get called.
+ raise NotImplementedError()
- def build_and_send_edu(self, destination, edu_type, content, key=None):
+ def build_and_send_edu(
+ self,
+ destination: str,
+ edu_type: str,
+ content: JsonDict,
+ key: Optional[Hashable] = None,
+ ) -> None:
"""As per FederationSender"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
@@ -218,38 +237,39 @@ class FederationRemoteSendQueue:
self.notifier.on_new_replication_data()
- def send_read_receipt(self, receipt):
+ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""As per FederationSender
Args:
- receipt (synapse.types.ReadReceipt):
+ receipt:
"""
# nothing to do here: the replication listener will handle it.
- return defer.succeed(None)
- def send_presence(self, states):
+ def send_presence(self, states: List[UserPresenceState]) -> None:
"""As per FederationSender
Args:
- states (list(UserPresenceState))
+ states
"""
pos = self._next_pos()
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))
+ local_states = [s for s in states if self.is_mine_id(s.user_id)]
self.presence_map.update({state.user_id: state for state in local_states})
self.presence_changed[pos] = [state.user_id for state in local_states]
self.notifier.on_new_replication_data()
- def send_presence_to_destinations(self, states, destinations):
+ def send_presence_to_destinations(
+ self, states: Iterable[UserPresenceState], destinations: Iterable[str]
+ ) -> None:
"""As per FederationSender
Args:
- states (list[UserPresenceState])
- destinations (list[str])
+ states
+ destinations
"""
for state in states:
pos = self._next_pos()
@@ -258,15 +278,18 @@ class FederationRemoteSendQueue:
self.notifier.on_new_replication_data()
- def send_device_messages(self, destination):
+ def send_device_messages(self, destination: str) -> None:
"""As per FederationSender"""
# We don't need to replicate this as it gets sent down a different
# stream.
- def get_current_token(self):
+ def wake_destination(self, server: str) -> None:
+ pass
+
+ def get_current_token(self) -> int:
return self.pos - 1
- def federation_ack(self, instance_name, token):
+ def federation_ack(self, instance_name: str, token: int) -> None:
if self._sender_instances:
# If we have configured multiple federation sender instances we need
# to track their positions separately, and only clear the queue up
@@ -504,13 +527,16 @@ ParsedFederationStreamData = namedtuple(
)
-def process_rows_for_federation(transaction_queue, rows):
+def process_rows_for_federation(
+ transaction_queue: FederationSender,
+ rows: List[FederationStream.FederationStreamRow],
+) -> None:
"""Parse a list of rows from the federation stream and put them in the
transaction queue ready for sending to the relevant homeservers.
Args:
- transaction_queue (FederationSender)
- rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
+ transaction_queue
+ rows
"""
# The federation stream contains a bunch of different types of
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 24ebc4b803..8babb1ebbe 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter
from twisted.internet import defer
-import synapse
import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
@@ -40,9 +40,12 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import ReadReceipt, RoomStreamToken
+from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
sent_pdus_destination_dist_count = Counter(
@@ -65,8 +68,91 @@ CATCH_UP_STARTUP_DELAY_SEC = 15
CATCH_UP_STARTUP_INTERVAL_SEC = 5
-class FederationSender:
- def __init__(self, hs: "synapse.server.HomeServer"):
+class AbstractFederationSender(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def notify_new_events(self, max_token: RoomStreamToken) -> None:
+ """This gets called when we have some new events we might want to
+ send out to other servers.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
+ """Send a RR to any other servers in the room
+
+ Args:
+ receipt: receipt to be sent
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_presence(self, states: List[UserPresenceState]) -> None:
+ """Send the new presence states to the appropriate destinations.
+
+ This actually queues up the presence states ready for sending and
+ triggers a background task to process them and send out the transactions.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_presence_to_destinations(
+ self, states: Iterable[UserPresenceState], destinations: Iterable[str]
+ ) -> None:
+ """Send the given presence states to the given destinations.
+
+ Args:
+ destinations:
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def build_and_send_edu(
+ self,
+ destination: str,
+ edu_type: str,
+ content: JsonDict,
+ key: Optional[Hashable] = None,
+ ) -> None:
+ """Construct an Edu object, and queue it for sending
+
+ Args:
+ destination: name of server to send to
+ edu_type: type of EDU to send
+ content: content of EDU
+ key: clobbering key for this edu
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_device_messages(self, destination: str) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def wake_destination(self, destination: str) -> None:
+ """Called when we want to retry sending transactions to a remote.
+
+ This is mainly useful if the remote server has been down and we think it
+ might have come back.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_current_token(self) -> int:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def federation_ack(self, instance_name: str, token: int) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def get_replication_rows(
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
+ raise NotImplementedError()
+
+
+class FederationSender(AbstractFederationSender):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = hs.hostname
@@ -432,7 +518,7 @@ class FederationSender:
queue.flush_read_receipts_for_room(room_id)
@preserve_fn # the caller should not yield on this
- async def send_presence(self, states: List[UserPresenceState]):
+ async def send_presence(self, states: List[UserPresenceState]) -> None:
"""Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and
@@ -494,7 +580,7 @@ class FederationSender:
self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence")
- async def _process_presence_inner(self, states: List[UserPresenceState]):
+ async def _process_presence_inner(self, states: List[UserPresenceState]) -> None:
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
"""
@@ -516,9 +602,9 @@ class FederationSender:
self,
destination: str,
edu_type: str,
- content: dict,
+ content: JsonDict,
key: Optional[Hashable] = None,
- ):
+ ) -> None:
"""Construct an Edu object, and queue it for sending
Args:
@@ -545,7 +631,7 @@ class FederationSender:
self.send_edu(edu, key)
- def send_edu(self, edu: Edu, key: Optional[Hashable]):
+ def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
"""Queue an EDU for sending
Args:
@@ -563,7 +649,7 @@ class FederationSender:
else:
queue.send_edu(edu)
- def send_device_messages(self, destination: str):
+ def send_device_messages(self, destination: str) -> None:
if destination == self.server_name:
logger.warning("Not sending device update to ourselves")
return
@@ -575,7 +661,7 @@ class FederationSender:
self._get_per_destination_queue(destination).attempt_new_transaction()
- def wake_destination(self, destination: str):
+ def wake_destination(self, destination: str) -> None:
"""Called when we want to retry sending transactions to a remote.
This is mainly useful if the remote server has been down and we think it
@@ -599,6 +685,10 @@ class FederationSender:
# to a worker.
return 0
+ def federation_ack(self, instance_name: str, token: int) -> None:
+ # It is not expected that this gets called on FederationSender.
+ raise NotImplementedError()
+
@staticmethod
async def get_replication_rows(
instance_name: str, from_token: int, to_token: int, target_row_count: int
@@ -607,7 +697,7 @@ class FederationSender:
# to a worker.
return [], 0, False
- async def _wake_destinations_needing_catchup(self):
+ async def _wake_destinations_needing_catchup(self) -> None:
"""
Wakes up destinations that need catch-up and are not currently being
backed off from.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index bc3630e9e9..6624212d6f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -149,6 +149,9 @@ class OidcHandler:
Args:
request: the incoming request from the browser.
"""
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index ecd63e6596..ce4079f15c 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -71,8 +71,10 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__)
-_well_known_cache = TTLCache("well-known")
-_had_valid_well_known_cache = TTLCache("had-valid-well-known")
+_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
+_had_valid_well_known_cache = TTLCache(
+ "had-valid-well-known"
+) # type: TTLCache[bytes, bool]
@attr.s(slots=True, frozen=True)
@@ -88,8 +90,8 @@ class WellKnownResolver:
reactor: IReactorTime,
agent: IAgent,
user_agent: bytes,
- well_known_cache: Optional[TTLCache] = None,
- had_well_known_cache: Optional[TTLCache] = None,
+ well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
+ had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 10bd4a1461..aa146e8bb8 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,7 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type
import attr
@@ -262,7 +262,7 @@ logger = logging.getLogger(__name__)
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
-_homeserver_whitelist = None
+_homeserver_whitelist = None # type: Optional[Pattern[str]]
# Util methods
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 9e04b266e4..08350292ab 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
from typing import List, Set
@@ -101,7 +102,7 @@ CONDITIONAL_REQUIREMENTS = {
"txacme>=0.9.2",
# txacme depends on eliot. Eliot 1.8.0 is incompatible with
# python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
- 'eliot<1.8.0;python_version<"3.5.3"',
+ "eliot<1.8.0;python_version<'3.5.3'",
],
"saml2": [
# pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749)
@@ -133,6 +134,18 @@ for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
+# ensure there are no double-quote characters in any of the deps (otherwise the
+# 'pip install' incantation in DependencyException will break)
+for dep in itertools.chain(
+ REQUIREMENTS,
+ *CONDITIONAL_REQUIREMENTS.values(),
+):
+ if '"' in dep:
+ raise Exception(
+ "Dependency `%s` contains double-quote; use single-quotes instead" % (dep,)
+ )
+
+
def list_requirements():
return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS)
@@ -152,7 +165,7 @@ class DependencyException(Exception):
@property
def dependencies(self):
for i in self.args[0]:
- yield "'" + i + "'"
+ yield '"' + i + '"'
def check_requirements(for_feature=None):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index bb447f75b4..8abed1f52d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -312,16 +312,16 @@ class FederationAckCommand(Command):
NAME = "FEDERATION_ACK"
- def __init__(self, instance_name, token):
+ def __init__(self, instance_name: str, token: int):
self.instance_name = instance_name
self.token = token
@classmethod
- def from_line(cls, line):
+ def from_line(cls, line: str) -> "FederationAckCommand":
instance_name, token = line.split(" ")
return cls(instance_name, int(token))
- def to_line(self):
+ def to_line(self) -> str:
return "%s %s" % (self.instance_name, self.token)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 825900f64c..e829add257 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -104,7 +104,7 @@ tcp_outbound_commands_counter = Counter(
# A list of all connected protocols. This allows us to send metrics about the
# connections.
-connected_connections = []
+connected_connections = [] # type: List[BaseReplicationStreamProtocol]
logger = logging.getLogger(__name__)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9bcd13b009..9bb8e9e177 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
from synapse.replication.tcp.streams._base import (
Stream,
@@ -21,6 +22,9 @@ from synapse.replication.tcp.streams._base import (
make_http_update_function,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
@@ -38,7 +42,7 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
if hs.config.worker_app is None:
# master process: get updates from the FederationRemoteSendQueue.
# (if the master is configured to send federation itself, federation_sender
@@ -48,7 +52,9 @@ class FederationStream(Stream):
current_token = current_token_without_instance(
federation_sender.get_current_token
)
- update_function = federation_sender.get_replication_rows
+ update_function = (
+ federation_sender.get_replication_rows
+ ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
elif hs.should_send_federation():
# federation sender: Query master process
@@ -69,5 +75,7 @@ class FederationStream(Stream):
return 0
@staticmethod
- async def _stub_update_function(instance_name, from_token, upto_token, limit):
+ async def _stub_update_function(
+ instance_name: str, from_token: int, upto_token: int, limit: int
+ ) -> Tuple[list, int, bool]:
return [], upto_token, False
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 263d8ec076..cfe1bebb91 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -390,6 +390,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
async def on_POST(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index aaa56a7024..309bd2771b 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -833,6 +833,9 @@ class UserMediaRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 9452e7ca9f..c01ba14cd2 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -90,6 +90,9 @@ class SyncRestServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e590a0deab..c4ed9dfdb4 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -187,6 +187,8 @@ class PreviewUrlResource(DirectServeJsonResource):
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: SynapseRequest) -> None:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index 51acaa9a92..d9ffe84489 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -104,6 +104,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
respond_with_html(request, 200, html)
async def _async_render_POST(self, request: SynapseRequest):
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
diff --git a/synapse/server.py b/synapse/server.py
index 5e787e2281..e85b9391fa 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -60,7 +60,7 @@ from synapse.federation.federation_server import (
FederationServer,
)
from synapse.federation.send_queue import FederationRemoteSendQueue
-from synapse.federation.sender import FederationSender
+from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
@@ -571,7 +571,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return TransportLayerClient(self)
@cache_in_self
- def get_federation_sender(self):
+ def get_federation_sender(self) -> AbstractFederationSender:
if self.should_send_federation():
return FederationSender(self)
elif not self.config.worker_app:
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e2240703a7..97ec65f757 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -183,12 +183,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- is_all, known_absent, state_dict_ids = cache.get(group)
+ cache_entry = cache.get(group)
+ state_dict_ids = cache_entry.value
- if is_all or state_filter.is_full():
+ if cache_entry.full or state_filter.is_full():
# Either we have everything or want everything, either way
# `is_all` tells us whether we've gotten everything.
- return state_filter.filter_state(state_dict_ids), is_all
+ return state_filter.filter_state(state_dict_ids), cache_entry.full
# tracks whether any of our requested types are missing from the cache
missing_types = False
@@ -202,7 +203,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# There aren't any wild cards, so `concrete_types()` returns the
# complete list of event types we're wanting.
for key in state_filter.concrete_types():
- if key not in state_dict_ids and key not in known_absent:
+ if key not in state_dict_ids and key not in cache_entry.known_absent:
missing_types = True
break
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f968706334..48f64eeb38 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__)
-caches_by_name = {}
-collectors_by_name = {} # type: Dict
+caches_by_name = {} # type: Dict[str, Sized]
+collectors_by_name = {} # type: Dict[str, CacheMetric]
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 588d2d49f2..b3b413b02c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,26 +15,38 @@
import enum
import logging
import threading
-from collections import namedtuple
-from typing import Any
+from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+
+import attr
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+# The type of the cache keys.
+KT = TypeVar("KT")
+# The type of the dictionary keys.
+DKT = TypeVar("DKT")
+
+
+@attr.s(slots=True)
+class DictionaryEntry:
"""Returned when getting an entry from the cache
Attributes:
- full (bool): Whether the cache has the full or dict or just some keys.
+ full: Whether the cache has the full or dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
- known_absent (set): Keys that were looked up in the dict and were not
+ known_absent: Keys that were looked up in the dict and were not
there.
- value (dict): The full or partial dict value
+ value: The full or partial dict value
"""
+ full = attr.ib(type=bool)
+ known_absent = attr.ib()
+ value = attr.ib()
+
def __len__(self):
return len(self.value)
@@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
sentinel = object()
-class DictionaryCache:
+class DictionaryCache(Generic[KT, DKT]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
- def __init__(self, name, max_entries=1000):
+ def __init__(self, name: str, max_entries: int = 1000):
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
- ) # type: LruCache[Any, DictionaryEntry]
+ ) # type: LruCache[KT, DictionaryEntry]
self.name = name
self.sequence = 0
- self.thread = None
+ self.thread = None # type: Optional[threading.Thread]
- def check_thread(self):
+ def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
@@ -69,12 +81,14 @@ class DictionaryCache:
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, dict_keys=None):
+ def get(
+ self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
+ ) -> DictionaryEntry:
"""Fetch an entry out of the cache
Args:
key
- dict_key(list): If given a set of keys then return only those keys
+ dict_key: If given a set of keys then return only those keys
that exist in the cache.
Returns:
@@ -95,7 +109,7 @@ class DictionaryCache:
return DictionaryEntry(False, set(), {})
- def invalidate(self, key):
+ def invalidate(self, key: KT) -> None:
self.check_thread()
# Increment the sequence number so that any SELECT statements that
@@ -103,19 +117,25 @@ class DictionaryCache:
self.sequence += 1
self.cache.pop(key, None)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
self.check_thread()
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, fetched_keys=None):
+ def update(
+ self,
+ sequence: int,
+ key: KT,
+ value: Dict[DKT, Any],
+ fetched_keys: Optional[Set[DKT]] = None,
+ ) -> None:
"""Updates the entry in the cache
Args:
sequence
- key (K)
- value (dict[X,Y]): The value to update the cache with.
- fetched_keys (None|set[X]): All of the dictionary keys which were
+ key
+ value: The value to update the cache with.
+ fetched_keys: All of the dictionary keys which were
fetched from the database.
If None, this is the complete value for key K. Otherwise, it
@@ -131,7 +151,9 @@ class DictionaryCache:
else:
self._update_or_insert(key, value, fetched_keys)
- def _update_or_insert(self, key, value, known_absent):
+ def _update_or_insert(
+ self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
+ ) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
@@ -140,5 +162,5 @@ class DictionaryCache:
entry.known_absent.update(known_absent)
self.cache[key] = entry
- def _insert(self, key, value, known_absent):
+ def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6ce2a3d12b..96a8274940 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -15,6 +15,7 @@
import logging
import time
+from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union
import attr
from sortedcontainers import SortedList
@@ -23,15 +24,19 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object()
+SENTINEL = object() # type: Any
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
-class TTLCache:
+
+class TTLCache(Generic[KT, VT]):
"""A key/value cache implementation where each entry has its own TTL"""
- def __init__(self, cache_name, timer=time.time):
+ def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
- self._data = {}
+ self._data = {} # type: Dict[KT, _CacheEntry]
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
@@ -40,26 +45,27 @@ class TTLCache:
self._metrics = register_cache("ttl", cache_name, self, resizable=False)
- def set(self, key, value, ttl):
+ def set(self, key: KT, value: VT, ttl: float) -> None:
"""Add/update an entry in the cache
Args:
key: key for this entry
value: value for this entry
- ttl (float): TTL for this entry, in seconds
+ ttl: TTL for this entry, in seconds
"""
expiry = self._timer() + ttl
self.expire()
e = self._data.pop(key, SENTINEL)
- if e != SENTINEL:
+ if e is not SENTINEL:
+ assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
- def get(self, key, default=SENTINEL):
+ def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Get a value from the cache
Args:
@@ -72,23 +78,23 @@ class TTLCache:
"""
self.expire()
e = self._data.get(key, SENTINEL)
- if e == SENTINEL:
+ if e is SENTINEL:
self._metrics.inc_misses()
- if default == SENTINEL:
+ if default is SENTINEL:
raise KeyError(key)
return default
+ assert isinstance(e, _CacheEntry)
self._metrics.inc_hits()
return e.value
- def get_with_expiry(self, key):
+ def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
"""Get a value, and its expiry time, from the cache
Args:
key: key to look up
Returns:
- Tuple[Any, float, float]: the value from the cache, the expiry time
- and the TTL
+ A tuple of the value from the cache, the expiry time and the TTL
Raises:
KeyError if the entry is not found
@@ -102,7 +108,7 @@ class TTLCache:
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl
- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.
@@ -118,29 +124,30 @@ class TTLCache:
"""
self.expire()
e = self._data.pop(key, SENTINEL)
- if e == SENTINEL:
+ if e is SENTINEL:
self._metrics.inc_misses()
- if default == SENTINEL:
+ if default is SENTINEL:
raise KeyError(key)
return default
+ assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
self._metrics.inc_hits()
return e.value
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
return self.get(key)
- def __delitem__(self, key):
+ def __delitem__(self, key: KT) -> None:
self.pop(key)
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._data
- def __len__(self):
+ def __len__(self) -> int:
self.expire()
return len(self._data)
- def expire(self):
+ def expire(self) -> None:
"""Run the expiry on the cache. Any entries whose expiry times are due will
be removed
"""
@@ -158,7 +165,7 @@ class _CacheEntry:
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
- expiry_time = attr.ib()
- ttl = attr.ib()
+ expiry_time = attr.ib(type=float)
+ ttl = attr.ib(type=float)
key = attr.ib()
value = attr.ib()
|