diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3ed29a2c16..9fc8444228 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
from prometheus_client import Counter
@@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
class ApplicationServicesHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
@@ -247,7 +250,9 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
- async def _handle_typing(self, service: ApplicationService, new_token: int):
+ async def _handle_typing(
+ self, service: ApplicationService, new_token: int
+ ) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
@@ -259,7 +264,7 @@ class ApplicationServicesHandler:
)
return typing
- async def _handle_receipts(self, service: ApplicationService):
+ async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
@@ -271,7 +276,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
- ):
+ ) -> List[JsonDict]:
events = [] # type: List[JsonDict]
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
@@ -301,11 +306,11 @@ class ApplicationServicesHandler:
return events
- async def query_user_exists(self, user_id):
+ async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
Args:
- user_id(str): The user to query if they exist on any AS.
+ user_id: The user to query if they exist on any AS.
Returns:
True if this user exists on at least one application service.
"""
@@ -316,11 +321,13 @@ class ApplicationServicesHandler:
return True
return False
- async def query_room_alias_exists(self, room_alias):
+ async def query_room_alias_exists(
+ self, room_alias: RoomAlias
+ ) -> Optional[RoomAliasMapping]:
"""Check if an application service knows this room alias exists.
Args:
- room_alias(RoomAlias): The room alias to query.
+ room_alias: The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
@@ -336,10 +343,13 @@ class ApplicationServicesHandler:
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = await self.store.get_association_from_room_alias(room_alias)
- return result
+ return await self.store.get_association_from_room_alias(room_alias)
+
+ return None
- async def query_3pe(self, kind, protocol, fields):
+ async def query_3pe(
+ self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+ ) -> List[JsonDict]:
services = self._get_services_for_3pn(protocol)
results = await make_deferred_yieldable(
@@ -361,7 +371,9 @@ class ApplicationServicesHandler:
return ret
- async def get_3pe_protocols(self, only_protocol=None):
+ async def get_3pe_protocols(
+ self, only_protocol: Optional[str] = None
+ ) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
@@ -379,7 +391,7 @@ class ApplicationServicesHandler:
if info is not None:
protocols[p].append(info)
- def _merge_instances(infos):
+ def _merge_instances(infos: List[JsonDict]) -> JsonDict:
if not infos:
return {}
@@ -394,19 +406,17 @@ class ApplicationServicesHandler:
return combined
- for p in protocols.keys():
- protocols[p] = _merge_instances(protocols[p])
+ return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
- return protocols
-
- async def _get_services_for_event(self, event):
+ async def _get_services_for_event(
+ self, event: EventBase
+ ) -> List[ApplicationService]:
"""Retrieve a list of application services interested in this event.
Args:
- event(Event): The event to check. Can be None if alias_list is not.
+ event: The event to check. Can be None if alias_list is not.
Returns:
- list<ApplicationService>: A list of services interested in this
- event based on the service regex.
+ A list of services interested in this event based on the service regex.
"""
services = self.store.get_app_services()
@@ -420,17 +430,15 @@ class ApplicationServicesHandler:
return interested_list
- def _get_services_for_user(self, user_id):
+ def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
services = self.store.get_app_services()
- interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
- return interested_list
+ return [s for s in services if (s.is_interested_in_user(user_id))]
- def _get_services_for_3pn(self, protocol):
+ def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
services = self.store.get_app_services()
- interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
- return interested_list
+ return [s for s in services if s.is_interested_in_protocol(protocol)]
- async def _is_unknown_user(self, user_id):
+ async def _is_unknown_user(self, user_id: str) -> bool:
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
@@ -445,9 +453,8 @@ class ApplicationServicesHandler:
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
- async def _check_user_exists(self, user_id):
+ async def _check_user_exists(self, user_id: str) -> bool:
unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
- exists = await self.query_user_exists(user_id)
- return exists
+ return await self.query_user_exists(user_id)
return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dd14ab69d7..ff103cbb92 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,10 +18,20 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
-import bcrypt # type: ignore[import]
+import bcrypt
import pymacaroons
from synapse.api.constants import LoginType
@@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
@@ -982,17 +991,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
- user_id=str(user_info["user"]),
- device_id=user_info["device_id"],
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
- if user_info["token_id"] is not None:
+ if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"],)
+ user_info.user_id, (user_info.token_id,)
)
async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9c0096bae5..31f91e0a1a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.async_helpers import Linearizer
-from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -928,7 +927,7 @@ class EventCreationHandler:
# Ensure that we can round trip before trying to persist in db
try:
- dump = frozendict_json_encoder.encode(event.content)
+ dump = json_encoder.encode(event.content)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
@@ -1100,34 +1099,13 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
-
- def is_inviter_member_event(e):
- return e.type == EventTypes.Member and e.sender == event.sender
-
- current_state_ids = await context.get_current_state_ids()
-
- # We know this event is not an outlier, so this must be
- # non-None.
- assert current_state_ids is not None
-
- state_to_include_ids = [
- e_id
- for k, e_id in current_state_ids.items()
- if k[0] in self.room_invite_state_types
- or k == (EventTypes.Member, event.sender)
- ]
-
- state_to_include = await self.store.get_events(state_to_include_ids)
-
- event.unsigned["invite_room_state"] = [
- {
- "type": e.type,
- "state_key": e.state_key,
- "content": e.content,
- "sender": e.sender,
- }
- for e in state_to_include.values()
- ]
+ event.unsigned[
+ "invite_room_state"
+ ] = await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_invite_state_types,
+ membership_user_id=event.sender,
+ )
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 49a00eed9c..8e014c9bb5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer
MYPY = False
if MYPY:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class BasePresenceHandler(abc.ABC):
"""Parts of the PresenceHandler that are shared between workers and master"""
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC):
class PresenceHandler(BasePresenceHandler):
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
@@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):
class PresenceEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
@@ -1071,12 +1071,14 @@ class PresenceEventSource:
users_interested_in = await self._get_interested_in(user, explicit_room_id)
- user_ids_changed = set()
+ user_ids_changed = set() # type: Collection[str]
changed = None
if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key)
if changed is not None and len(changed) < 500:
+ assert isinstance(user_ids_changed, set)
+
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
get_updates_counter.labels("stream").inc()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a6f1d21674..ed1ff62599 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = await self.auth.get_user_by_access_token(guest_access_token)
- if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+ if (
+ not user_data.is_guest
+ or UserID.from_string(user_data.user_id).localpart != localpart
+ ):
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
user_id=user_id,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index c5b1f1f1e1..e73031475f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -771,22 +771,29 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
- for invitee in invite_list:
+ # we avoid dropping the lock between invites, as otherwise joins can
+ # start coming in and making the createRoom slow.
+ #
+ # we also don't need to check the requester's shadow-ban here, as we
+ # have already done so above (and potentially emptied invite_list).
+ with (await self.room_member_handler.member_linearizer.queue((room_id,))):
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
content["is_direct"] = is_direct
- # Note that update_membership with an action of "invite" can raise a
- # ShadowBanError, but this was handled above by emptying invite_list.
- _, last_stream_id = await self.room_member_handler.update_membership(
- requester,
- UserID.from_string(invitee),
- room_id,
- "invite",
- ratelimit=False,
- content=content,
- )
+ for invitee in invite_list:
+ (
+ _,
+ last_stream_id,
+ ) = await self.room_member_handler.update_membership_locked(
+ requester,
+ UserID.from_string(invitee),
+ room_id,
+ "invite",
+ ratelimit=False,
+ content=content,
+ )
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0268288600..7e5e53a56f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -327,7 +327,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# haproxy would have timed the request out anyway...
raise SynapseError(504, "took to long to process")
- result = await self._update_membership(
+ result = await self.update_membership_locked(
requester,
target,
room_id,
@@ -342,7 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result
- async def _update_membership(
+ async def update_membership_locked(
self,
requester: Requester,
target: UserID,
@@ -355,6 +355,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[str, int]:
+ """Helper for update_membership.
+
+ Assumes that the membership linearizer is already held for the room.
+ """
content_specified = bool(content)
if content is None:
content = {}
|