summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/acme.py12
-rw-r--r--synapse/handlers/acme_issuing_service.py27
-rw-r--r--synapse/handlers/cas_handler.py6
-rw-r--r--synapse/handlers/federation.py5
-rw-r--r--synapse/handlers/groups_local.py83
-rw-r--r--synapse/handlers/message.py42
-rw-r--r--synapse/handlers/search.py38
-rw-r--r--synapse/handlers/set_password.py10
-rw-r--r--synapse/handlers/state_deltas.py14
-rw-r--r--synapse/handlers/stats.py39
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/user_directory.py9
12 files changed, 222 insertions, 132 deletions
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import twisted
 import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
 
 from synapse.app import check_bind_error
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 
 
 class AcmeHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain
 
-    async def start_listening(self):
+    async def start_listening(self) -> None:
         from synapse.handlers import acme_issuing_service
 
         # Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
             logger.error(ACME_REGISTER_FAIL_ERROR)
             raise
 
-    async def provision_certificate(self):
+    async def provision_certificate(self) -> None:
 
         logger.warning("Reprovisioning %s", self._acme_domain)
 
@@ -110,5 +114,3 @@ class AcmeHandler:
         except Exception:
             logger.exception("Failed saving!")
             raise
-
-        return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
+from typing import Dict, Iterable, List
 
 import attr
+import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
 from zope.interface import implementer
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
+from twisted.web.resource import IResource
 
 logger = logging.getLogger(__name__)
 
 
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+    reactor: IReactorTCP,
+    acme_url: str,
+    account_key_file: str,
+    well_known_resource: IResource,
+) -> AcmeIssuingService:
     """Create an ACME issuing service, and attach it to a web Resource
 
     Args:
         reactor: twisted reactor
-        acme_url (str): URL to use to request certificates
-        account_key_file (str): where to store the account key
-        well_known_resource (twisted.web.IResource): web resource for .well-known.
+        acme_url: URL to use to request certificates
+        account_key_file: where to store the account key
+        well_known_resource: web resource for .well-known.
             we will attach a child resource for "acme-challenge".
 
     Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
     A store that only stores in memory.
     """
 
-    certs = attr.ib(default=attr.Factory(dict))
+    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
 
-    def store(self, server_name, pem_objects):
+    def store(
+        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+    ) -> defer.Deferred:
         self.certs[server_name] = [o.as_bytes() for o in pem_objects]
         return defer.succeed(None)
 
 
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
     """Load the ACME account key from a file, creating it if it does not exist.
 
     Args:
-        key_file (str): name of the file to use as the account key
+        key_file: name of the file to use as the account key
     """
     # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
     # hardcode the 'client.key' filename
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 0f342c607b..21b6bc4992 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -99,11 +99,7 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s%s?%s" % (
-            self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
-            urllib.parse.urlencode(args),
-        )
+        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
 
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..b6dc7f99b6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index df29edeb83..71f11ef94a 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
 
 
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
     get_group_role = _create_rerouter("get_group_role")
     get_group_roles = _create_rerouter("get_group_roles")
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the group summary for a group.
 
         If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get users in a group
         """
         if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_users_in_group(
+            return await self.groups_server_handler.get_users_in_group(
                 group_id, requester_user_id
             )
-            return res
 
         group_server_name = get_domain_from_id(group_id)
 
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_joined_groups(self, user_id):
+    async def get_joined_groups(self, user_id: str) -> JsonDict:
         group_ids = await self.store.get_joined_groups(user_id)
         return {"groups": group_ids}
 
-    async def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
         if self.hs.is_mine_id(user_id):
             result = await self.store.get_publicised_groups_for_user(user_id)
 
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
             # TODO: Verify attestations
             return {"groups": result}
 
-    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
-        destinations = {}
+    async def bulk_get_publicised_groups(
+        self, user_ids: Iterable[str], proxy: bool = True
+    ) -> JsonDict:
+        destinations = {}  # type: Dict[str, Set[str]]
         local_users = set()
 
         for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []
+        failed_results = []  # type: List[str]
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
 
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
     set_group_join_policy = _create_rerouter("set_group_join_policy")
 
-    async def create_group(self, group_id, user_id, content):
+    async def create_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Create a group
         """
 
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
             local_attestation = None
             remote_attestation = None
         else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
-            try:
-                res = await self.transport_client.create_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
+            raise SynapseError(400, "Unable to create remote groups")
 
         is_publicised = content.get("publicise", False)
         token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def join_group(self, group_id, user_id, content):
+    async def join_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Request to join a group
         """
         if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def accept_invite(self, group_id, user_id, content):
+    async def accept_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Accept an invite to a group
         """
         if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def invite(self, group_id, user_id, requester_user_id, config):
+    async def invite(
+        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+    ) -> JsonDict:
         """Invite a user to a group
         """
         content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def on_invite(self, group_id, user_id, content):
+    async def on_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """One of our users were invited to a group
         """
         # TODO: Support auto join and rejection
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {"state": "invite", "user_profile": user_profile}
 
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Remove a user from a group
         """
         if user_id == requester_user_id:
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def user_removed_from_group(self, group_id, user_id, content):
+    async def user_removed_from_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> None:
         """One of our users was removed/kicked from a group
         """
         # TODO: Check if user in group
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..e2a7d567fa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -432,6 +432,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
 
 import itertools
 import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
 
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SearchHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    async def search(self, user, content, batch=None):
+    async def search(
+        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+    ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user (UserID)
-            content (dict): Search parameters
-            batch (str): The next_batch parameter. Used for pagination.
+            user
+            content: Search parameters
+            batch: The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []
+            historical_room_ids = []  # type: List[str]
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
 
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        room_groups = {}  # Holds result of grouping by room, if applicable
-        sender_group = {}  # Holds result of grouping by sender, if applicable
+        # Holds result of grouping by room, if applicable
+        room_groups = {}  # type: Dict[str, JsonDict]
+        # Holds result of grouping by sender, if applicable
+        sender_group = {}  # type: Dict[str, JsonDict]
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []
+            room_events = []  # type: List[EventBase]
             i = 0
 
             pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
 
         state_results = {}
         if include_state:
-            rooms = {e.room_id for e in allowed_events}
-            for room_id in rooms:
+            for room_id in {e.room_id for e in allowed_events}:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
-            state_results.values()
-
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong
 
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
 
         if state_results:
             s = {}
-            for room_id, state in state_results.items():
+            for room_id, state_events in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state, time_now
+                    state_events, time_now
                 )
 
             rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
-        self._password_policy_handler = hs.get_password_policy_handler()
 
     async def set_password(
         self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class StateDeltasHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+    async def _get_key_change(
+        self,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        key_name: str,
+        public_value: str,
+    ) -> Optional[bool]:
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.
 
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..d261d7cd4e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +37,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -56,7 +62,7 @@ class StatsHandler:
             # we start populating stats
             self.clock.call_later(0, self.notify_new_event)
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.stats_enabled or self._is_processing:
@@ -72,7 +78,7 @@ class StatsHandler:
 
         run_as_background_process("stats.notify_new_event", process)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
@@ -110,10 +116,10 @@ class StatsHandler:
             )
 
             for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, {}).update(fields)
+                room_deltas.setdefault(room_id, Counter()).update(fields)
 
             for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, {}).update(fields)
+                user_deltas.setdefault(user_id, Counter()).update(fields)
 
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +137,20 @@ class StatsHandler:
 
             self.pos = max_pos
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(
+        self, deltas: Iterable[JsonDict]
+    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process
 
         Returns:
-            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}
-        user_to_stats_deltas = {}
+        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
 
-        room_to_state_updates = {}
+        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
 
         for delta in deltas:
             typ = delta["type"]
@@ -173,7 +180,7 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}
+            event_content = {}  # type: JsonDict
 
             sender = None
             if event_id is not None:
@@ -257,13 +264,13 @@ class StatsHandler:
                     )
 
                     if has_changed_joinedness:
-                        delta = +1 if membership == Membership.JOIN else -1
+                        membership_delta = +1 if membership == Membership.JOIN else -1
 
                         user_to_stats_deltas.setdefault(user_id, Counter())[
                             "joined_rooms"
-                        ] += delta
+                        ] += membership_delta
 
-                        room_stats_delta["local_users_in_room"] += delta
+                        room_stats_delta["local_users_in_room"] += membership_delta
 
             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]
 
-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
-    def _reset(self):
+    def _reset(self) -> None:
         """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
 
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         """Should be called whenever we receive updates for typing stream.
         """
 
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
 
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial
 
 
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)
 
         if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         if was_present:
             # No point sending another notification
-            return None
+            return
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._stopped_typing(member)
 
-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
 
-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return
 
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
         self._push_update(member=member, typing=False)
 
-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update_local(member=member, typing=typing)
 
-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])
 
-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
 
-        # If still None then the initial background update hasn't happened yet
-        if self.pos is None:
-            return None
-
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "user_dir_delta"):
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                     if change:  # The user joined
                         event = await self.store.get_event(event_id, allow_none=True)
+                        # It isn't expected for this event to not exist, but we
+                        # don't want the entire background process to break.
+                        if event is None:
+                            continue
+
                         profile = ProfileInfo(
                             avatar_url=event.content.get("avatar_url"),
                             display_name=event.content.get("displayname"),