diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/appservice/scheduler.py | 18 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 20 | ||||
-rw-r--r-- | synapse/handlers/initial_sync.py | 8 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 74 | ||||
-rw-r--r-- | synapse/logging/_remote.py | 225 | ||||
-rw-r--r-- | synapse/logging/_terse_json.py | 199 | ||||
-rw-r--r-- | synapse/storage/databases/main/profile.py | 6 |
7 files changed, 322 insertions, 228 deletions
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index ad3c408519..58291afc22 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -60,6 +60,13 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +# Maximum number of events to provide in an AS transaction. +MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100 + +# Maximum number of ephemeral events to provide in an AS transaction. +MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 + + class ApplicationServiceScheduler: """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this @@ -136,10 +143,17 @@ class _ServiceQueuer: self.requests_in_flight.add(service.id) try: while True: - events = self.queued_events.pop(service.id, []) - ephemeral = self.queued_ephemeral.pop(service.id, []) + all_events = self.queued_events.get(service.id, []) + events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION] + del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION] + + all_events_ephemeral = self.queued_ephemeral.get(service.id, []) + ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] + del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] + if not events and not ephemeral: return + try: await self.txn_ctrl.send(service, events, ephemeral) except Exception: diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 0206320e96..bd8e71ae56 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional import synapse.state import synapse.storage @@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -30,11 +34,7 @@ class BaseHandler: Common base class for the event handlers. """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # type: synapse.storage.DataStore self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -56,7 +56,7 @@ class BaseHandler: clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, - ) + ) # type: Optional[Ratelimiter] else: self.admin_redaction_ratelimiter = None @@ -127,15 +127,15 @@ class BaseHandler: if guest_access != "can_join": if context: current_state_ids = await context.get_current_state_ids() - current_state = await self.store.get_events( + current_state_dict = await self.store.get_events( list(current_state_ids.values()) ) + current_state = list(current_state_dict.values()) else: - current_state = await self.state_handler.get_current_state( + current_state_map = await self.state_handler.get_current_state( event.room_id ) - - current_state = list(current_state.values()) + current_state = list(current_state_map.values()) logger.info("maybe_kick_guest_users %r", current_state) await self.kick_guest_users(current_state) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 98075f48d2..cb11754bf8 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler): user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: + # The member_event_id will always be available if membership is set + # to leave. + assert member_event_id + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) @@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler): user_id: str, room_id: str, pagin_config: PaginationConfig, - membership: Membership, + membership: str, member_event_id: str, is_peeking: bool, ) -> JsonDict: @@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler): user_id: str, room_id: str, pagin_config: PaginationConfig, - membership: Membership, + membership: str, is_peeking: bool, ) -> JsonDict: current_state = await self.state.get_current_state(room_id=room_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b78a12ad01..92700b589c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random +from typing import TYPE_CHECKING, Optional from synapse.api.errors import ( AuthError, @@ -25,10 +25,19 @@ from synapse.api.errors import ( SynapseError, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.types import UserID, create_requester, get_domain_from_id +from synapse.types import ( + JsonDict, + Requester, + UserID, + create_requester, + get_domain_from_id, +) from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) MAX_DISPLAYNAME_LEN = 256 @@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation = hs.get_federation_client() @@ -60,7 +69,7 @@ class ProfileHandler(BaseHandler): self._update_remote_profile_cache, self.PROFILE_UPDATE_MS ) - async def get_profile(self, user_id): + async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): @@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler): except HttpResponseException as e: raise e.to_synapse_error() - async def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id: str) -> JsonDict: """Get the profile information from our local cache. If the user is ours then the profile information will always be corect. Otherwise, it may be out of date/missing. @@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler): profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - async def get_displayname(self, target_user): + async def get_displayname(self, target_user: UserID) -> str: if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname( @@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler): return result["displayname"] async def set_displayname( - self, target_user, requester, new_displayname, by_admin=False - ): + self, + target_user: UserID, + requester: Requester, + new_displayname: str, + by_admin: bool = False, + ) -> None: """Set the displayname of a user Args: - target_user (UserID): the user whose displayname is to be changed. - requester (Requester): The user attempting to make this change. - new_displayname (str): The displayname to give this user. - by_admin (bool): Whether this change was made by an administrator. + target_user: the user whose displayname is to be changed. + requester: The user attempting to make this change. + new_displayname: The displayname to give this user. + by_admin: Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler): 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) + displayname_to_set = new_displayname # type: Optional[str] if new_displayname == "": - new_displayname = None + displayname_to_set = None # If the admin changes the display name of a user, the requesting user cannot send # the join event to update the displayname in the rooms. @@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - await self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname( + target_user.localpart, displayname_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) @@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user: UserID) -> str: if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url( @@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler): return result["avatar_url"] async def set_avatar_url( - self, target_user, requester, new_avatar_url, by_admin=False + self, + target_user: UserID, + requester: Requester, + new_avatar_url: str, + by_admin: bool = False, ): """Set a new avatar URL for a user. Args: - target_user (UserID): the user whose avatar URL is to be changed. - requester (Requester): The user attempting to make this change. - new_avatar_url (str): The avatar URL to give this user. - by_admin (bool): Whether this change was made by an administrator. + target_user: the user whose avatar URL is to be changed. + requester: The user attempting to make this change. + new_avatar_url: The avatar URL to give this user. + by_admin: Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def on_profile_query(self, args): + async def on_profile_query(self, args: JsonDict) -> JsonDict: user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler): return response - async def _update_join_states(self, requester, target_user): + async def _update_join_states( + self, requester: Requester, target_user: UserID + ) -> None: if not self.hs.is_mine(target_user): return @@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e) ) - async def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed( + self, target_user: UserID, requester: Optional[UserID] = None + ) -> None: """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users share a room. Args: - target_user (UserID): The owner of the queried profile. - requester (None|UserID): The user querying for the profile. + target_user: The owner of the queried profile. + requester: The user querying for the profile. Raises: SynapseError(403): The two users share no room, or ne user couldn't diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py new file mode 100644 index 0000000000..0caf325916 --- /dev/null +++ b/synapse/logging/_remote.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import traceback +from collections import deque +from ipaddress import IPv4Address, IPv6Address, ip_address +from math import floor +from typing import Callable, Optional + +import attr +from zope.interface import implementer + +from twisted.application.internet import ClientService +from twisted.internet.defer import Deferred +from twisted.internet.endpoints import ( + HostnameEndpoint, + TCP4ClientEndpoint, + TCP6ClientEndpoint, +) +from twisted.internet.interfaces import IPushProducer, ITransport +from twisted.internet.protocol import Factory, Protocol +from twisted.logger import ILogObserver, Logger, LogLevel + + +@attr.s +@implementer(IPushProducer) +class LogProducer: + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + format_event: A callable to format the log entry to a string. + """ + + transport = attr.ib(type=ITransport) + format_event = attr.ib(type=Callable[[dict], str]) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = deque() + + def resumeProducing(self): + self._paused = False + + while self._paused is False and (self._buffer and self.transport.connected): + try: + # Request the next event and format it. + event = self._buffer.popleft() + msg = self.format_event(event) + + # Send it as a new line over the transport. + self.transport.write(msg.encode("utf8")) + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + +@attr.s +@implementer(ILogObserver) +class TCPLogObserver: + """ + An IObserver that writes JSON logs to a TCP target. + + Args: + hs (HomeServer): The homeserver that is being logged for. + host: The host of the logging target. + port: The logging target's port. + format_event: A callable to format the log entry to a string. + maximum_buffer: The maximum buffer size. + """ + + hs = attr.ib() + host = attr.ib(type=str) + port = attr.ib(type=int) + format_event = attr.ib(type=Callable[[dict], str]) + maximum_buffer = attr.ib(type=int) + _buffer = attr.ib(default=attr.Factory(deque), type=deque) + _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) + _logger = attr.ib(default=attr.Factory(Logger)) + _producer = attr.ib(default=None, type=Optional[LogProducer]) + + def start(self) -> None: + + # Connect without DNS lookups if it's a direct IP. + try: + ip = ip_address(self.host) + if isinstance(ip, IPv4Address): + endpoint = TCP4ClientEndpoint( + self.hs.get_reactor(), self.host, self.port + ) + elif isinstance(ip, IPv6Address): + endpoint = TCP6ClientEndpoint( + self.hs.get_reactor(), self.host, self.port + ) + else: + raise ValueError("Unknown IP address provided: %s" % (self.host,)) + except ValueError: + endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) + + factory = Factory.forProtocol(Protocol) + self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) + self._service.startService() + self._connect() + + def stop(self): + self._service.stopService() + + def _connect(self) -> None: + """ + Triggers an attempt to connect then write to the remote if not already writing. + """ + if self._connection_waiter: + return + + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + @self._connection_waiter.addErrback + def fail(r): + r.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() + + @self._connection_waiter.addCallback + def writer(r): + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and r.transport is self._producer.transport: + self._producer.resumeProducing() + self._connection_waiter = None + return + + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer( + buffer=self._buffer, + transport=r.transport, + format_event=self.format_event, + ) + r.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None + + def _handle_pressure(self) -> None: + """ + Handle backpressure by shedding events. + + The buffer will, in this order, until the buffer is below the maximum: + - Shed DEBUG events + - Shed INFO events + - Shed the middle 50% of the events. + """ + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out DEBUGs + self._buffer = deque( + filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out INFOs + self._buffer = deque( + filter(lambda event: event["log_level"] != LogLevel.info, self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Cut the middle entries out + buffer_split = floor(self.maximum_buffer / 2) + + old_buffer = self._buffer + self._buffer = deque() + + for i in range(buffer_split): + self._buffer.append(old_buffer.popleft()) + + end_buffer = [] + for i in range(buffer_split): + end_buffer.append(old_buffer.pop()) + + self._buffer.extend(reversed(end_buffer)) + + def __call__(self, event: dict) -> None: + self._buffer.append(event) + + # Handle backpressure, if it exists. + try: + self._handle_pressure() + except Exception: + # If handling backpressure fails,clear the buffer and log the + # exception. + self._buffer.clear() + self._logger.failure("Failed clearing backpressure") + + # Try and write immediately. + self._connect() diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 1b8916cfa2..9b46956ca9 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -18,26 +18,11 @@ Log formatters that output terse JSON. """ import json -import sys -import traceback -from collections import deque -from ipaddress import IPv4Address, IPv6Address, ip_address -from math import floor -from typing import IO, Optional +from typing import IO -import attr -from zope.interface import implementer +from twisted.logger import FileLogObserver -from twisted.application.internet import ClientService -from twisted.internet.defer import Deferred -from twisted.internet.endpoints import ( - HostnameEndpoint, - TCP4ClientEndpoint, - TCP6ClientEndpoint, -) -from twisted.internet.interfaces import IPushProducer, ITransport -from twisted.internet.protocol import Factory, Protocol -from twisted.logger import FileLogObserver, ILogObserver, Logger +from synapse.logging._remote import TCPLogObserver _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) @@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb return FileLogObserver(outFile, formatEvent) -@attr.s -@implementer(IPushProducer) -class LogProducer: +def TerseJSONToTCPLogObserver( + hs, host: str, port: int, metadata: dict, maximum_buffer: int +) -> FileLogObserver: """ - An IPushProducer that writes logs from its buffer to its transport when it - is resumed. - - Args: - buffer: Log buffer to read logs from. - transport: Transport to write to. - """ - - transport = attr.ib(type=ITransport) - _buffer = attr.ib(type=deque) - _paused = attr.ib(default=False, type=bool, init=False) - - def pauseProducing(self): - self._paused = True - - def stopProducing(self): - self._paused = True - self._buffer = deque() - - def resumeProducing(self): - self._paused = False - - while self._paused is False and (self._buffer and self.transport.connected): - try: - event = self._buffer.popleft() - self.transport.write(_encoder.encode(event).encode("utf8")) - self.transport.write(b"\n") - except Exception: - # Something has gone wrong writing to the transport -- log it - # and break out of the while. - traceback.print_exc(file=sys.__stderr__) - break - - -@attr.s -@implementer(ILogObserver) -class TerseJSONToTCPLogObserver: - """ - An IObserver that writes JSON logs to a TCP target. + A log observer that formats events to a flattened JSON representation. Args: hs (HomeServer): The homeserver that is being logged for. host: The host of the logging target. port: The logging target's port. - metadata: Metadata to be added to each log entry. + metadata: Metadata to be added to each log object. + maximum_buffer: The maximum buffer size. """ - hs = attr.ib() - host = attr.ib(type=str) - port = attr.ib(type=int) - metadata = attr.ib(type=dict) - maximum_buffer = attr.ib(type=int) - _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) - _logger = attr.ib(default=attr.Factory(Logger)) - _producer = attr.ib(default=None, type=Optional[LogProducer]) - - def start(self) -> None: - - # Connect without DNS lookups if it's a direct IP. - try: - ip = ip_address(self.host) - if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) - elif isinstance(ip, IPv6Address): - endpoint = TCP6ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) - except ValueError: - endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) - - factory = Factory.forProtocol(Protocol) - self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) - self._service.startService() - self._connect() - - def stop(self): - self._service.stopService() - - def _connect(self) -> None: - """ - Triggers an attempt to connect then write to the remote if not already writing. - """ - if self._connection_waiter: - return - - self._connection_waiter = self._service.whenConnected(failAfterFailures=1) - - @self._connection_waiter.addErrback - def fail(r): - r.printTraceback(file=sys.__stderr__) - self._connection_waiter = None - self._connect() - - @self._connection_waiter.addCallback - def writer(r): - # We have a connection. If we already have a producer, and its - # transport is the same, just trigger a resumeProducing. - if self._producer and r.transport is self._producer.transport: - self._producer.resumeProducing() - self._connection_waiter = None - return - - # If the producer is still producing, stop it. - if self._producer: - self._producer.stopProducing() - - # Make a new producer and start it. - self._producer = LogProducer(buffer=self._buffer, transport=r.transport) - r.transport.registerProducer(self._producer, True) - self._producer.resumeProducing() - self._connection_waiter = None - - def _handle_pressure(self) -> None: - """ - Handle backpressure by shedding events. - - The buffer will, in this order, until the buffer is below the maximum: - - Shed DEBUG events - - Shed INFO events - - Shed the middle 50% of the events. - """ - if len(self._buffer) <= self.maximum_buffer: - return - - # Strip out DEBUGs - self._buffer = deque( - filter(lambda event: event["level"] != "DEBUG", self._buffer) - ) - - if len(self._buffer) <= self.maximum_buffer: - return - - # Strip out INFOs - self._buffer = deque( - filter(lambda event: event["level"] != "INFO", self._buffer) - ) - - if len(self._buffer) <= self.maximum_buffer: - return - - # Cut the middle entries out - buffer_split = floor(self.maximum_buffer / 2) - - old_buffer = self._buffer - self._buffer = deque() - - for i in range(buffer_split): - self._buffer.append(old_buffer.popleft()) - - end_buffer = [] - for i in range(buffer_split): - end_buffer.append(old_buffer.pop()) - - self._buffer.extend(reversed(end_buffer)) - - def __call__(self, event: dict) -> None: - flattened = flatten_event(event, self.metadata, include_time=True) - self._buffer.append(flattened) - - # Handle backpressure, if it exists. - try: - self._handle_pressure() - except Exception: - # If handling backpressure fails,clear the buffer and log the - # exception. - self._buffer.clear() - self._logger.failure("Failed clearing backpressure") + def formatEvent(_event: dict) -> str: + flattened = flatten_event(_event, metadata, include_time=True) + return _encoder.encode(flattened) + "\n" - # Try and write immediately. - self._connect() + return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 1681caa1f0..a6d1eb908a 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: str + self, user_localpart: str, new_displayname: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", @@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore): async def get_remote_profile_cache_entries_that_expire( self, last_checked: int - ) -> Dict[str, str]: + ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked` """ |