diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index c1ade1333b..c5d1eb952b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -537,8 +537,7 @@ class Auth(object):
return defer.succeed(auth_ids)
- @defer.inlineCallbacks
- def check_can_change_room_list(self, room_id: str, user: UserID):
+ async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
@@ -547,17 +546,17 @@ class Auth(object):
user
"""
- is_admin = yield self.is_server_admin(user)
+ is_admin = await self.is_server_admin(user)
if is_admin:
return True
user_id = user.to_string()
- yield self.check_user_in_room(room_id, user_id)
+ await self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
- power_level_event = yield self.state.get_current_state(
+ power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 0ace7b787d..667ad20428 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
- def stream_positions(self):
- # We must update this typing token from the response of the previous
- # sync. In particular, the stream id may "reset" back to zero/a low
- # value which we *must* use for the next replication request.
- return {"typing": self._latest_room_serial}
-
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
@@ -652,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
else:
self.send_handler = None
- async def on_rdata(self, stream_name, token, rows):
- await super(GenericWorkerReplicationHandler, self).on_rdata(
- stream_name, token, rows
- )
- await self.process_and_notify(stream_name, token, rows)
+ async def on_rdata(self, stream_name, instance_name, token, rows):
+ await super().on_rdata(stream_name, instance_name, token, rows)
+ await self._process_and_notify(stream_name, instance_name, token, rows)
- def get_streams_to_replicate(self):
- args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
- args.update(self.typing_handler.stream_positions())
- if self.send_handler:
- args.update(self.send_handler.stream_positions())
- return args
-
- async def process_and_notify(self, stream_name, token, rows):
+ async def _process_and_notify(self, stream_name, instance_name, token, rows):
try:
if self.send_handler:
await self.send_handler.process_replication_rows(
@@ -799,9 +784,6 @@ class FederationSenderHandler(object):
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)
- def stream_positions(self):
- return {"federation": self.federation_position}
-
async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a0071fec94..687cd841ac 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -883,18 +883,37 @@ class FederationClient(FederationBase):
def get_public_rooms(
self,
- destination,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
+ remote_server: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[Dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
):
- if destination == self.server_name:
- return
+ """Get the list of public rooms from a remote homeserver
+
+ Args:
+ remote_server: The name of the remote server
+ limit: Maximum amount of rooms to return
+ since_token: Used for result pagination
+ search_filter: A filter dictionary to send the remote homeserver
+ and filter the result set
+ include_all_networks: Whether to include results from all third party instances
+ third_party_instance_id: Whether to only include results from a specific third
+ party instance
+
+ Returns:
+ Deferred[Dict[str, Any]]: The response from the remote server, or None if
+ `remote_server` is the same as the local server_name
+ Raises:
+ HttpResponseException: There was an exception returned from the remote server
+ SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
+ requests over federation
+
+ """
return self.transport_layer.get_public_rooms(
- destination,
+ remote_server,
limit,
since_token,
search_filter,
@@ -957,14 +976,13 @@ class FederationClient(FederationBase):
return signed_events
- @defer.inlineCallbacks
- def forward_third_party_invite(self, destinations, room_id, event_dict):
+ async def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
- yield self.transport_layer.exchange_third_party_invite(
+ await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
+from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> list[user_id]
+ # Pending presence map user_id -> UserPresenceState
+ self.presence_map = {} # type: Dict[str, UserPresenceState]
+
+ # Stream position -> list[user_id]
+ self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = SortedDict()
+ self.presence_destinations = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, List[str]]]
+
+ # (destination, key) -> EDU
+ self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
- self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
+ # stream position -> (destination, key)
+ self.keyed_edu_changed = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, tuple]]
- self.edus = SortedDict() # stream position -> Edu
+ self.edus = SortedDict() # type: SortedDict[int, Edu]
+ # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
- self.pos_time = SortedDict()
+
+ # map from stream ID to the time that stream entry was generated, so that we
+ # can clear out entries after a while
+ self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
- to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
- for edu_key in to_del:
+ keys_to_del = [
+ edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+ ]
+ for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
- from_token (int)
- to_token(int)
- limit (int)
- federation_ack (int): Optional. The position where the worker is
- explicitly acknowledged it has handled. Allows us to drop
- data from before that point
+ instance_name: the name of the current process
+ from_token: the previous stream token: the starting point for fetching the
+ updates
+ to_token: the new stream token: the point to get updates up to
+ target_row_count: a target for the number of rows to be returned.
+
+ Returns: a triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of `(token, row)` entries.
+ * `new_last_token` is the new position in stream.
+ * `limited` is whether there are more updates to fetch.
"""
- # TODO: Handle limit.
+ # TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = []
-
- # There should be only one reader, so lets delete everything its
- # acknowledged its seen.
- if federation_ack:
- self._clear_queue_before_pos(federation_ack)
+ rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
- return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+ return (
+ [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+ to_token,
+ False,
+ )
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-TypeToRow = {
- Row.TypeId: Row
- for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+ PresenceRow,
+ PresenceDestinationsRow,
+ KeyedEduRow,
+ EduRow,
+) # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self) -> int:
+ @staticmethod
+ def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
+ @staticmethod
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
- return []
+ return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+if TYPE_CHECKING:
+ import synapse.server
+
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import TYPE_CHECKING, List
from canonicaljson import json
-import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 383e3fdc8b..060bf07197 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,13 +15,14 @@
# limitations under the License.
import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
+from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
@@ -326,18 +327,25 @@ class TransportLayerClient(object):
@log_function
def get_public_rooms(
self,
- remote_server,
- limit,
- since_token,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
+ remote_server: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[Dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
):
+ """Get the list of public rooms from a remote homeserver
+
+ See synapse.federation.federation_client.FederationClient.get_public_rooms for
+ more information.
+ """
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
- data = {"include_all_networks": "true" if include_all_networks else "false"}
+ data = {
+ "include_all_networks": "true" if include_all_networks else "false"
+ } # type: Dict[str, Any]
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
@@ -347,9 +355,19 @@ class TransportLayerClient(object):
data["filter"] = search_filter
- response = yield self.client.post_json(
- destination=remote_server, path=path, data=data, ignore_backoff=True
- )
+ try:
+ response = yield self.client.post_json(
+ destination=remote_server, path=path, data=data, ignore_backoff=True
+ )
+ except HttpResponseException as e:
+ if e.code == 403:
+ raise SynapseError(
+ 403,
+ "You are not allowed to view the public rooms list of %s"
+ % (remote_server,),
+ errcode=Codes.FORBIDDEN,
+ )
+ raise
else:
path = _create_v1_path("/publicRooms")
@@ -363,9 +381,19 @@ class TransportLayerClient(object):
if since_token:
args["since"] = [since_token]
- response = yield self.client.get_json(
- destination=remote_server, path=path, args=args, ignore_backoff=True
- )
+ try:
+ response = yield self.client.get_json(
+ destination=remote_server, path=path, args=args, ignore_backoff=True
+ )
+ except HttpResponseException as e:
+ if e.code == 403:
+ raise SynapseError(
+ 403,
+ "You are not allowed to view the public rooms list of %s"
+ % (remote_server,),
+ errcode=Codes.FORBIDDEN,
+ )
+ raise
return response
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 4f0dc0a209..4acb4fa489 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError()
- @defer.inlineCallbacks
- def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ async def remove_user_from_group(
+ self, group_id, user_id, requester_user_id, content
+ ):
"""Remove a user from the group; either a user is leaving or an admin
kicked them.
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_kick = False
if requester_user_id != user_id:
- is_admin = yield self.store.is_user_admin_in_group(
+ is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
@@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_kick = True
- yield self.store.remove_user_from_group(group_id, user_id)
+ await self.store.remove_user_from_group(group_id, user_id)
if is_kick:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
- yield groups_local.user_removed_from_group(group_id, user_id, {})
+ await groups_local.user_removed_from_group(group_id, user_id, {})
else:
- yield self.transport_client.remove_user_from_group_notification(
+ await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
if not self.hs.is_mine_id(user_id):
- yield self.store.maybe_delete_remote_profile_cache(user_id)
+ await self.store.maybe_delete_remote_profile_cache(user_id)
# Delete group if the last user has left
- users = yield self.store.get_users_in_group(group_id, include_private=True)
+ users = await self.store.get_users_in_group(group_id, include_private=True)
if not users:
- yield self.store.delete_group(group_id)
+ await self.store.delete_group(group_id)
return {}
- @defer.inlineCallbacks
- def create_group(self, group_id, requester_user_id, content):
- group = yield self.check_group_is_ours(group_id, requester_user_id)
+ async def create_group(self, group_id, requester_user_id, content):
+ group = await self.check_group_is_ours(group_id, requester_user_id)
logger.info("Attempting to create group with ID: %r", group_id)
@@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if group:
raise SynapseError(400, "Group already exists")
- is_admin = yield self.auth.is_server_admin(
+ is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
if not is_admin:
@@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
long_description = profile.get("long_description")
user_profile = content.get("user_profile", {})
- yield self.store.create_group(
+ await self.store.create_group(
group_id,
requester_user_id,
name=name,
@@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if not self.hs.is_mine_id(requester_user_id):
remote_attestation = content["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation, user_id=requester_user_id, group_id=group_id
)
@@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = None
remote_attestation = None
- yield self.store.add_user_to_group(
+ await self.store.add_user_to_group(
group_id,
requester_user_id,
is_admin=True,
@@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
)
if not self.hs.is_mine_id(requester_user_id):
- yield self.store.add_remote_profile_cache(
+ await self.store.add_remote_profile_cache(
requester_user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
@@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id}
- @defer.inlineCallbacks
- def delete_group(self, group_id, requester_user_id):
+ async def delete_group(self, group_id, requester_user_id):
"""Deletes a group, kicking out all current members.
Only group admins or server admins can call this request
@@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
Deferred
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups.
- is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
+ is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin:
- is_admin = yield self.auth.is_server_admin(
+ is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
@@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
- users = yield self.store.get_users_in_group(group_id, include_private=True)
+ users = await self.store.get_users_in_group(group_id, include_private=True)
- @defer.inlineCallbacks
- def _kick_user_from_group(user_id):
+ async def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
- yield groups_local.user_removed_from_group(group_id, user_id, {})
+ await groups_local.user_removed_from_group(group_id, user_id, {})
else:
- yield self.transport_client.remove_user_from_group_notification(
+ await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
- yield self.store.maybe_delete_remote_profile_cache(user_id)
+ await self.store.maybe_delete_remote_profile_cache(user_id)
# We kick users out in the order of:
# 1. Non-admins
@@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else:
non_admins.append(u["user_id"])
- yield concurrently_execute(_kick_user_from_group, non_admins, 10)
- yield concurrently_execute(_kick_user_from_group, admins, 10)
- yield _kick_user_from_group(requester_user_id)
+ await concurrently_execute(_kick_user_from_group, non_admins, 10)
+ await concurrently_execute(_kick_user_from_group, admins, 10)
+ await _kick_user_from_group(requester_user_id)
- yield self.store.delete_group(group_id)
+ await self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 51413d910e..3b781d9836 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now))
)
- @defer.inlineCallbacks
- def maybe_kick_guest_users(self, event, context=None):
+ async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
- current_state_ids = yield context.get_current_state_ids()
- current_state = yield self.store.get_events(
+ current_state_ids = await context.get_current_state_ids()
+ current_state = await self.store.get_events(
list(current_state_ids.values())
)
else:
- current_state = yield self.state_handler.get_current_state(
+ current_state = await self.state_handler.get_current_state(
event.room_id
)
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
- yield self.kick_guest_users(current_state)
+ await self.kick_guest_users(current_state)
- @defer.inlineCallbacks
- def kick_guest_users(self, current_state):
+ async def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
@@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
- yield handler.update_membership(
+ await handler.update_membership(
requester,
target_user,
member_event.room_id,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 53e5f585d9..f2f16b1e43 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
room_alias, room_id, servers, creator=creator
)
- @defer.inlineCallbacks
- def create_association(
+ async def create_association(
self,
requester: Requester,
room_alias: RoomAlias,
@@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
else:
# Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room).
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
if (self.require_membership and check_membership) and not is_admin:
- rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ rooms_for_user = await self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
403, "You must be in the room to create an alias for it"
@@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
- can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
+ can_create = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- yield self._create_association(room_alias, room_id, servers, creator=user_id)
+ await self._create_association(room_alias, room_id, servers, creator=user_id)
- @defer.inlineCallbacks
- def delete_association(self, requester: Requester, room_alias: RoomAlias):
+ async def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string()
try:
- can_delete = yield self._user_can_delete_alias(room_alias, user_id)
+ can_delete = await self._user_can_delete_alias(room_alias, user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown room alias")
@@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
+ can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- room_id = yield self._delete_association(room_alias)
+ room_id = await self._delete_association(room_alias)
try:
- yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
+ await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
- @defer.inlineCallbacks
- def _update_canonical_alias(
+ async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
"""
- alias_event = yield self.state.get_current_state(
+ alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
@@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"]
if send_update:
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock
return defer.succeed(True)
- @defer.inlineCallbacks
- def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
+ async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
One of the following must be true:
@@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
for the current room.
"""
- creator = yield self.store.get_room_alias_creator(alias.to_string())
+ creator = await self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
return True
# Resolve the alias to the corresponding room.
- room_mapping = yield self.get_association(alias)
+ room_mapping = await self.get_association(alias)
room_id = room_mapping["room_id"]
if not room_id:
return False
- res = yield self.auth.check_can_change_room_list(
+ res = await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id)
)
return res
- @defer.inlineCallbacks
- def edit_published_room_list(
+ async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
):
"""Edit the entry of the room in the published room list.
@@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to publish rooms to the room list"
)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
- can_change_room_list = yield self.auth.check_can_change_room_list(
+ can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user
)
if not can_change_room_list:
@@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
making_public = visibility == "public"
if making_public:
- room_aliases = yield self.store.get_aliases_for_room(room_id)
- canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
+ room_aliases = await self.store.get_aliases_for_room(room_id)
+ canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
if canonical_alias:
room_aliases.append(canonical_alias)
@@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
- yield self.store.set_room_is_public(room_id, making_public)
+ await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 41b96c0a73..4e5c645525 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals],
}
- @defer.inlineCallbacks
@log_function
- def exchange_third_party_invite(
+ async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
third_party_invite = {"signed": signed}
@@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
- if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
- room_version = yield self.store.get_room_version_id(room_id)
+ if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+ room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
- event, context = yield self.event_creation_handler.create_new_client_event(
+ event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- event, context = yield self.add_display_name_to_third_party_invite(
+ event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
- yield self.auth.check_from_context(room_version, event, context)
+ await self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, context)
+ await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
- yield member_handler.send_membership_event(None, event, context)
+ await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
- yield self.federation_client.forward_third_party_invite(
+ await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ad22415782..ca5c83811a 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
- @defer.inlineCallbacks
- def create_group(self, group_id, user_id, content):
+ async def create_group(self, group_id, user_id, content):
"""Create a group
"""
logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.create_group(
+ res = await self.groups_server_handler.create_group(
group_id, user_id, content
)
local_attestation = None
@@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
- content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+ content["user_profile"] = await self.profile_handler.get_profile(user_id)
try:
- res = yield self.transport_client.create_group(
+ res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
)
is_publicised = content.get("publicise", False)
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
- @defer.inlineCallbacks
- def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ async def remove_user_from_group(
+ self, group_id, user_id, requester_user_id, content
+ ):
"""Remove a user from a group
"""
if user_id == requester_user_id:
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
@@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# retry if the group server is currently down.
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.remove_user_from_group(
+ res = await self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
try:
- res = yield self.transport_client.remove_user_from_group(
+ res = await self.transport_client.remove_user_from_group(
get_domain_from_id(group_id),
group_id,
requester_user_id,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 522271eed1..a324f09340 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
- @defer.inlineCallbacks
- def send_nonmember_event(self, requester, event, context, ratelimit=True):
+ async def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
@@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
- prev_state = yield self.deduplicate_state_event(event, context)
+ prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
@@ -656,7 +655,7 @@ class EventCreationHandler(object):
)
return prev_state
- yield self.handle_new_client_event(
+ await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event
return
- @defer.inlineCallbacks
- def create_and_send_nonmember_event(
+ async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
@@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
- with (yield self.limiter.queue(event_dict["room_id"])):
- event, context = yield self.create_event(
+ with (await self.limiter.queue(event_dict["room_id"])):
+ event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
@@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
- yield self.send_nonmember_event(
+ await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
@@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context)
@measure_func("handle_new_client_event")
- @defer.inlineCallbacks
- def handle_new_client_event(
+ async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
@@ -794,9 +791,9 @@ class EventCreationHandler(object):
):
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
- room_version = yield self.store.get_room_version_id(event.room_id)
+ room_version = await self.store.get_room_version_id(event.room_id)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -805,7 +802,7 @@ class EventCreationHandler(object):
)
try:
- yield self.auth.check_from_context(room_version, event, context)
+ await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.action_generator.handle_push_actions_for_event(event, context)
+ await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@@ -826,7 +823,7 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
- yield self.send_event_to_master(
+ await self.send_event_to_master(
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True
return
- yield self.persist_and_notify_client_event(
+ await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
@@ -883,8 +880,7 @@ class EventCreationHandler(object):
Codes.BAD_ALIAS,
)
- @defer.inlineCallbacks
- def persist_and_notify_client_event(
+ async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
@@ -901,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
- original_event = yield self.store.get_event(
+ original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@@ -913,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender
)
- yield self.base_handler.ratelimit(
+ await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
- yield self.base_handler.maybe_kick_guest_users(event, context)
+ await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases.
@@ -927,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
- original_event = yield self.store.get_event(original_event_id)
+ original_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
@@ -937,7 +933,7 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
- yield self._validate_canonical_alias(
+ await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
@@ -957,7 +953,7 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
- yield self._validate_canonical_alias(
+ await self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
@@ -969,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [
e_id
@@ -978,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender)
]
- state_to_include = yield self.store.get_events(state_to_include_ids)
+ state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
@@ -996,8 +992,8 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
- returned_invite = yield defer.ensureDeferred(
- federation_handler.send_invite(invitee.domain, event)
+ returned_invite = await federation_handler.send_invite(
+ invitee.domain, event
)
event.unsigned.pop("room_state", None)
@@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
- original_event = yield self.store.get_event(
+ original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
- prev_state_ids = yield context.get_prev_state_ids()
- auth_events_ids = yield self.auth.compute_auth_events(
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
- room_version = yield self.store.get_room_version_id(event.room_id)
+ room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction(
@@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
+ event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context
)
@@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+ await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
@@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
except Exception:
logger.exception("Error bumping presence active time")
- @defer.inlineCallbacks
- def _send_dummy_events_to_fill_extremities(self):
+ async def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
self._expire_rooms_to_exclude_from_dummy_event_insertion()
- room_ids = yield self.store.get_rooms_with_many_extremities(
+ room_ids = await self.store.get_rooms_with_many_extremities(
min_count=10,
limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
@@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send
# the dummy event with.
- latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- members = yield self.state.get_current_users_in_room(
+ members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
dummy_event_sent = False
@@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
continue
requester = create_requester(user_id)
try:
- event, context = yield self.create_event(
+ event, context = await self.create_event(
requester,
{
"type": "org.matrix.dummy_event",
@@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
- yield self.send_nonmember_event(
+ await self.send_nonmember_event(
requester, event, context, ratelimit=False
)
dummy_event_sent = True
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6aa1c0f5e0..302efc1b9a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"]
- @defer.inlineCallbacks
- def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
+ async def set_displayname(
+ self, target_user, requester, new_displayname, by_admin=False
+ ):
"""Set the displayname of a user
Args:
@@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
- profile = yield self.store.get_profileinfo(target_user.localpart)
+ profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
400,
@@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
- yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(target_user.localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"]
- @defer.inlineCallbacks
- def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
+ async def set_avatar_url(
+ self, target_user, requester, new_avatar_url, by_admin=False
+ ):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
@@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url:
- profile = yield self.store.get_profileinfo(target_user.localpart)
+ profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
- yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(target_user.localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response
- @defer.inlineCallbacks
- def _update_join_states(self, requester, target_user):
+ async def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user):
return
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
- room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
+ room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
try:
# Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data.
- yield handler.update_membership(
+ await handler.update_membership(
requester,
target_user,
room_id,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3a65b46ecd..1e6bdac0ad 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
"""Registers a new client on the server.
Args:
- localpart : The local part of the user ID to register. If None,
+ localpart: The local part of the user ID to register. If None,
one will be generated.
- password (unicode) : The password to assign to this user so they can
+ password (unicode): The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from
@@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
- yield self._auto_join_rooms(user_id)
+ yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
return user_id
- @defer.inlineCallbacks
- def _auto_join_rooms(self, user_id):
+ async def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
- is_real_user = yield self.store.is_real_user(user_id)
+ is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
- count = yield self.store.count_real_users()
+ count = await self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
@@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency
# loop
- yield self.hs.get_room_creation_handler().create_room(
+ await self.hs.get_room_creation_handler().create_room(
fake_requester,
config={
"preset": "public_chat",
@@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
else:
- yield self._join_user_to_room(fake_requester, r)
+ await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
@@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- @defer.inlineCallbacks
- def post_consent_actions(self, user_id):
+ async def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
"""
- yield self._auto_join_rooms(user_id)
+ await self._auto_join_rooms(user_id)
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
@@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1
return str(id)
- @defer.inlineCallbacks
- def _join_user_to_room(self, requester, room_identifier):
+ async def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+ room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias
)
room_id = room_id.to_string()
@@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- yield room_member_handler.update_membership(
+ await room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
return (device_id, access_token)
- @defer.inlineCallbacks
- def post_registration_actions(self, user_id, auth_result, access_token):
+ async def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration
Args:
@@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
- yield self._post_registration_client(
+ await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
return
@@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
- yield self.store.upsert_monthly_active_user(user_id)
+ await self.store.upsert_monthly_active_user(user_id)
- yield self._register_email_threepid(user_id, threepid, access_token)
+ await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
- yield self._register_msisdn_threepid(user_id, threepid)
+ await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
- yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
+ await self._on_user_consented(user_id, self.hs.config.user_consent_version)
- @defer.inlineCallbacks
- def _on_user_consented(self, user_id, consent_version):
+ async def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration
Args:
@@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
- yield self.store.user_set_consent_version(user_id, consent_version)
- yield self.post_consent_actions(user_id)
+ await self.store.user_set_consent_version(user_id, consent_version)
+ await self.post_consent_actions(user_id)
@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3d10e4b2d9..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def upgrade_room(
+ async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
- ret = yield self._upgrade_response_cache.wrap(
+ ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -148,17 +145,16 @@ class RoomCreationHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def _upgrade_room(
+ async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
user_id = requester.user.to_string()
# start by allocating a new room id
- r = yield self.store.get_room(old_room_id)
+ r = await self.store.get_room(old_room_id)
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = yield self._generate_room_id(
+ new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
)
@@ -169,7 +165,7 @@ class RoomCreationHandler(BaseHandler):
(
tombstone_event,
tombstone_context,
- ) = yield self.event_creation_handler.create_event(
+ ) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Tombstone,
@@ -183,12 +179,12 @@ class RoomCreationHandler(BaseHandler):
},
token_id=requester.access_token_id,
)
- old_room_version = yield self.store.get_room_version_id(old_room_id)
- yield self.auth.check_from_context(
+ old_room_version = await self.store.get_room_version_id(old_room_id)
+ await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
- yield self.clone_existing_room(
+ await self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@@ -197,32 +193,31 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
- yield self.event_creation_handler.send_nonmember_event(
+ await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context
)
- old_room_state = yield tombstone_context.get_current_state_ids()
+ old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases
- yield self._move_aliases_to_new_room(
+ await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
)
# Copy over user push rules, tags and migrate room directory state
- yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id
)
# finally, shut down the PLs in the old room, and update them in the new
# room.
- yield self._update_upgraded_room_pls(
+ await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
)
return new_room_id
- @defer.inlineCallbacks
- def _update_upgraded_room_pls(
+ async def _update_upgraded_room_pls(
self,
requester: Requester,
old_room_id: str,
@@ -249,7 +244,7 @@ class RoomCreationHandler(BaseHandler):
)
return
- old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
+ old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally
@@ -278,7 +273,7 @@ class RoomCreationHandler(BaseHandler):
if updated:
try:
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@@ -292,7 +287,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@@ -304,8 +299,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
- @defer.inlineCallbacks
- def clone_existing_room(
+ async def clone_existing_room(
self,
requester: Requester,
old_room_id: str,
@@ -338,7 +332,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable
# Get old room's create event
- old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
+ old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True):
@@ -361,11 +355,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""),
)
- old_room_state_ids = yield self.store.get_filtered_current_state_ids(
+ old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
- old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
+ old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id)
@@ -400,7 +394,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level
- yield self._send_events_for_new_room(
+ await self._send_events_for_new_room(
requester,
new_room_id,
# we expect to override all the presets with initial_state, so this is
@@ -412,12 +406,12 @@ class RoomCreationHandler(BaseHandler):
)
# Transfer membership events
- old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
+ old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
- old_room_member_state_events = yield self.store.get_events(
+ old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
@@ -426,7 +420,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content
and old_event.content["membership"] == "ban"
):
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
new_room_id,
@@ -438,8 +432,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins
# XXX 3pid invites
- @defer.inlineCallbacks
- def _move_aliases_to_new_room(
+ async def _move_aliases_to_new_room(
self,
requester: Requester,
old_room_id: str,
@@ -448,13 +441,13 @@ class RoomCreationHandler(BaseHandler):
):
directory_handler = self.hs.get_handlers().directory_handler
- aliases = yield self.store.get_aliases_for_room(old_room_id)
+ aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
- canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
+ canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@@ -472,7 +465,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
- yield directory_handler.delete_association(requester, alias)
+ await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@@ -485,7 +478,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
- yield directory_handler.create_association(
+ await directory_handler.create_association(
requester,
RoomAlias.from_string(alias),
new_room_id,
@@ -502,7 +495,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information.
try:
if canonical_alias_event:
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -518,8 +511,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e)
- @defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
+ async def create_room(
+ self, requester, config, ratelimit=True, creator_join_profile=None
+ ):
""" Creates a new room.
Args:
@@ -547,7 +541,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
if (
self._server_notices_mxid is not None
@@ -556,11 +550,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
- event_allowed = yield self.third_party_event_rules.on_create_room(
+ event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
if not event_allowed:
@@ -574,7 +568,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
room_version_id = config.get(
"room_version", self.config.default_room_version.identifier
@@ -597,7 +591,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
- mapping = yield self.store.get_association_from_room_alias(room_alias)
+ mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@@ -612,7 +606,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
- yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
+ await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
if (
@@ -631,13 +625,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
- room_id = yield self._generate_room_id(
+ room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version,
)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
- yield directory_handler.create_association(
+ await directory_handler.create_association(
requester=requester,
room_id=room_id,
room_alias=room_alias,
@@ -670,7 +664,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
- yield self._send_events_for_new_room(
+ await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@@ -684,7 +678,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -698,7 +692,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -716,7 +710,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@@ -730,7 +724,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
- yield self.hs.get_room_member_handler().do_3pid_invite(
+ await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -748,8 +742,7 @@ class RoomCreationHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def _send_events_for_new_room(
+ async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
@@ -769,11 +762,10 @@ class RoomCreationHandler(BaseHandler):
return e
- @defer.inlineCallbacks
- def send(etype, content, **kwargs):
+ async def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
@@ -784,10 +776,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
- yield send(etype=EventTypes.Create, content=creation_content)
+ await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@@ -800,7 +792,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- yield send(etype=EventTypes.PowerLevels, content=pl_content)
+ await send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
"users": {creator_id: 100},
@@ -833,36 +825,35 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
- yield send(etype=EventTypes.PowerLevels, content=power_level_content)
+ await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
- yield send(etype=etype, state_key=state_key, content=content)
+ await send(etype=etype, state_key=state_key, content=content)
- @defer.inlineCallbacks
- def _generate_room_id(
+ async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -874,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
- yield self.store.store_room(
+ await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -893,8 +884,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -913,7 +903,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
- users = yield self.store.get_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -921,17 +911,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
- filtered = yield (filter_evts([event]))
+ filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- results = yield self.store.get_events_around(
+ results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -939,8 +929,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
- results["events_before"] = yield filter_evts(results["events_before"])
- results["events_after"] = yield filter_evts(results["events_after"])
+ results["events_before"] = await filter_evts(results["events_before"])
+ results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -967,7 +957,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.state_store.get_state_for_events(
+ state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -975,7 +965,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
- results["state"] = yield filter_evts(state_events)
+ results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -994,13 +984,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(
+ async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
- to_key = yield self.get_current_key()
+ to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1013,11 +1002,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- room_events = yield self.store.get_membership_changes_for_user(
+ room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c3ee8db4f0..53b49bc15f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -142,8 +142,7 @@ class RoomMemberHandler(object):
"""
raise NotImplementedError()
- @defer.inlineCallbacks
- def _local_membership_update(
+ async def _local_membership_update(
self,
requester,
target,
@@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield self.event_creation_handler.create_event(
+ event, context = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@@ -182,18 +181,18 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield self.event_creation_handler.deduplicate_state_event(
+ duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
- yield self.event_creation_handler.handle_new_client_event(
+ await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -203,15 +202,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield self._user_joined_room(target, room_id)
+ await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- yield self._user_left_room(target, room_id)
+ await self._user_left_room(target, room_id)
return event
@@ -253,8 +252,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
- @defer.inlineCallbacks
- def update_membership(
+ async def update_membership(
self,
requester,
target,
@@ -269,8 +267,8 @@ class RoomMemberHandler(object):
):
key = (room_id,)
- with (yield self.member_linearizer.queue(key)):
- result = yield self._update_membership(
+ with (await self.member_linearizer.queue(key)):
+ result = await self._update_membership(
requester,
target,
room_id,
@@ -285,8 +283,7 @@ class RoomMemberHandler(object):
return result
- @defer.inlineCallbacks
- def _update_membership(
+ async def _update_membership(
self,
requester,
target,
@@ -321,7 +318,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
- yield self.federation_handler.exchange_third_party_invite(
+ await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@@ -332,7 +329,7 @@ class RoomMemberHandler(object):
remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"):
- is_blocked = yield self.store.is_room_blocked(room_id)
+ is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@@ -351,7 +348,7 @@ class RoomMemberHandler(object):
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@@ -370,9 +367,9 @@ class RoomMemberHandler(object):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
- latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- current_state_ids = yield self.state_handler.get_current_state_ids(
+ current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
)
@@ -381,7 +378,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
- old_state = yield self.store.get_event(old_state_id, allow_none=True)
+ old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
@@ -413,7 +410,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE
):
- is_blocked = yield self._is_server_notice_room(room_id)
+ is_blocked = await self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
@@ -424,18 +421,18 @@ class RoomMemberHandler(object):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = yield self._is_host_in_room(current_state_ids)
+ is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = yield self._can_guest_join(current_state_ids)
+ guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- inviter = yield self._get_inviter(target.to_string(), room_id)
+ inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -443,13 +440,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler
if not content_specified:
- content["displayname"] = yield profile.get_displayname(target)
- content["avatar_url"] = yield profile.get_avatar_url(target)
+ content["displayname"] = await profile.get_displayname(target)
+ content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
- remote_join_response = yield self._remote_join(
+ remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
@@ -458,7 +455,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = yield self._get_inviter(target.to_string(), room_id)
+ inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@@ -472,12 +469,12 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
- res = yield self._remote_reject_invite(
+ res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
return res
- res = yield self._local_membership_update(
+ res = await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@@ -572,8 +569,7 @@ class RoomMemberHandler(object):
)
continue
- @defer.inlineCallbacks
- def send_membership_event(self, requester, event, context, ratelimit=True):
+ async def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@@ -599,27 +595,27 @@ class RoomMemberHandler(object):
else:
requester = types.create_requester(target_user)
- prev_event = yield self.event_creation_handler.deduplicate_state_event(
+ prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if prev_event is not None:
return
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
- guest_can_join = yield self._can_guest_join(prev_state_ids)
+ guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN):
- is_blocked = yield self.store.is_room_blocked(room_id)
+ is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield self.event_creation_handler.handle_new_client_event(
+ await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
@@ -633,15 +629,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield self._user_joined_room(target_user, room_id)
+ await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- yield self._user_left_room(target_user, room_id)
+ await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@@ -699,8 +695,7 @@ class RoomMemberHandler(object):
if invite:
return UserID.from_string(invite.sender)
- @defer.inlineCallbacks
- def do_3pid_invite(
+ async def do_3pid_invite(
self,
room_id,
inviter,
@@ -712,7 +707,7 @@ class RoomMemberHandler(object):
id_access_token=None,
):
if self.config.block_non_admin_invites:
- is_requester_admin = yield self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -720,9 +715,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- yield self.base_handler.ratelimit(requester)
+ await self.base_handler.ratelimit(requester)
- can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
)
if not can_invite:
@@ -737,16 +732,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server"
)
- invitee = yield self.identity_handler.lookup_3pid(
+ invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee:
- yield self.update_membership(
+ await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
- yield self._make_and_store_3pid_invite(
+ await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@@ -757,8 +752,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
- @defer.inlineCallbacks
- def _make_and_store_3pid_invite(
+ async def _make_and_store_3pid_invite(
self,
requester,
id_server,
@@ -769,7 +763,7 @@ class RoomMemberHandler(object):
txn_id,
id_access_token=None,
):
- room_state = yield self.state_handler.get_current_state(room_id)
+ room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -807,7 +801,7 @@ class RoomMemberHandler(object):
public_keys,
fallback_public_key,
display_name,
- ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+ ) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@@ -823,7 +817,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
- @defer.inlineCallbacks
- def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity
- too_complex = yield self._is_remote_room_too_complex(
+ too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts
)
if too_complex is True:
@@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
- yield defer.ensureDeferred(
- self.federation_handler.do_invite_join(
- remote_room_hosts, room_id, user.to_string(), content
- )
+ await self.federation_handler.do_invite_join(
+ remote_room_hosts, room_id, user.to_string(), content
)
- yield self._user_joined_room(user, room_id)
+ await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return
# Check again, but with the local state events
- too_complex = yield self._is_local_room_too_complex(room_id)
+ too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
@@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
- yield self.update_membership(
+ await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
raise SynapseError(
@@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
- return user_joined_room(self.distributor, target, room_id)
+ return defer.succeed(user_joined_room(self.distributor, target, room_id))
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
- return user_left_room(self.distributor, target, room_id)
+ return defer.succeed(user_left_room(self.distributor, target, room_id))
@defer.inlineCallbacks
def forget(self, user, room_id):
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 88a5a97caf..71d9ed62b0 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -273,10 +273,9 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
- @defer.inlineCallbacks
- def _notify_app_services(self, room_stream_id):
+ async def _notify_app_services(self, room_stream_id):
try:
- yield self.appservice_handler.notify_interested_services(room_stream_id)
+ await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
@@ -475,20 +474,18 @@ class Notifier(object):
return result
- @defer.inlineCallbacks
- def _get_room_ids(self, user, explicit_room_id):
- joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
+ async def _get_room_ids(self, user, explicit_room_id):
+ joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
return [explicit_room_id], True
- if (yield self._is_world_readable(explicit_room_id)):
+ if await self._is_world_readable(explicit_room_id):
return [explicit_room_id], False
raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True
- @defer.inlineCallbacks
- def _is_world_readable(self, room_id):
- state = yield self.state_handler.get_current_state(
+ async def _is_world_readable(self, room_id):
+ state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 1be1ccbdf3..f88c80ae84 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
import abc
import logging
import re
+from inspect import signature
from typing import Dict, List, Tuple
from six import raise_from
@@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
+ Requests are sent to master process by default, but can be sent to other
+ named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
+ # We reserve `instance_name` as a parameter to sending requests, so we
+ # assert here that sub classes don't try and use the name.
+ assert (
+ "instance_name" not in self.PATH_ARGS
+ ), "`instance_name` is a reserved paramater name"
+ assert (
+ "instance_name"
+ not in signature(self.__class__._serialize_payload).parameters
+ ), "`instance_name` is a reserved paramater name"
+
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
@@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
- def send_request(**kwargs):
+ def send_request(instance_name="master", **kwargs):
+ # Currently we only support sending requests to master process.
+ if instance_name != "master":
+ raise Exception("Unknown instance")
+
data = yield cls._serialize_payload(**kwargs)
url_args = [
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index f35cebc710..0459f582bf 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)
+ self._instance_name = hs.get_instance_name()
+
# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
upto_token = parse_integer(request, "upto_token", required=True)
updates, upto_token, limited = await stream.get_updates_since(
- from_token, upto_token
+ self._instance_name, from_token, upto_token
)
return (
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 751c799d94..5d7c8871a4 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Optional
+from typing import Optional
import six
@@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self.hs = hs
- def stream_positions(self) -> Dict[str, int]:
- """
- Get the current positions of all the streams this store wants to subscribe to
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- pos = {}
- if self._cache_id_gen:
- pos["caches"] = self._cache_id_gen.get_current_token()
- return pos
-
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index ebe94909cb..65e54b1c71 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedAccountDataStore, self).stream_positions()
- position = self._account_data_id_gen.get_current_token()
- result["user_account_data"] = position
- result["room_account_data"] = position
- result["tag_account_data"] = position
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 0c237c6e0f..c923751e50 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
- def stream_positions(self):
- result = super(SlavedDeviceInboxStore, self).stream_positions()
- result["to_device"] = self._device_inbox_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 23b1650e41..58fb0eaae3 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
- def stream_positions(self):
- result = super(SlavedDeviceStore, self).stream_positions()
- # The user signature stream uses the same stream ID generator as the
- # device list stream, so set them both to the device list ID
- # generator's current token.
- current_token = self._device_list_id_gen.get_current_token()
- result[DeviceListsStream.NAME] = current_token
- result[UserSignatureStream.NAME] = current_token
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index e73342c657..15011259df 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,12 +93,6 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedEventStore, self).stream_positions()
- result["events"] = self._stream_id_gen.get_current_token()
- result["backfill"] = -self._backfill_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 2d4fd08cf5..01bcf0e882 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedGroupServerStore, self).stream_positions()
- result["groups"] = self._group_updates_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index ad8f0c15a9..fae3125072 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedPresenceStore, self).stream_positions()
-
- if self.hs.config.use_presence:
- position = self._presence_id_gen.get_current_token()
- result["presence"] = position
-
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index eebd5a1fb6..6138796da4 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedPushRuleStore, self).stream_positions()
- result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index bce8a3d115..67be337945 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def stream_positions(self):
- result = super(SlavedPusherStore, self).stream_positions()
- result["pushers"] = self._pushers_id_gen.get_current_token()
- return result
-
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index d40dc6e1f5..993432edcb 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedReceiptsStore, self).stream_positions()
- result["receipts"] = self._receipts_id_gen.get_current_token()
- return result
-
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 3a20f45316..10dda8708f 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def stream_positions(self):
- result = super(RoomStore, self).stream_positions()
- result["public_rooms"] = self._public_room_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d07b8b2d0..3bbf3c3569 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
"""
import logging
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
from twisted.internet.protocol import ReconnectingClientFactory
@@ -86,37 +86,22 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore):
self.store = store
- async def on_rdata(self, stream_name: str, token: int, rows: list):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
+ stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
+ token: stream token for this batch of rows
+ rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
- def get_streams_to_replicate(self) -> Dict[str, int]:
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- args = self.store.stream_positions()
- user_account_data = args.pop("user_account_data", None)
- room_account_data = args.pop("room_account_data", None)
- if user_account_data:
- args["account_data"] = user_account_data
- elif room_account_data:
- args["account_data"] = room_account_data
- return args
-
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 6f7054d5af..2d1d119c7c 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -278,19 +278,24 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
- await self.on_rdata(stream_name, cmd.token, rows)
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
- async def on_rdata(self, stream_name: str, token: int, rows: list):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- await self._replication_data_handler.on_rdata(stream_name, token, rows)
+ await self._replication_data_handler.on_rdata(
+ stream_name, instance_name, token, rows
+ )
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
@@ -314,15 +319,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(cmd.stream_name, [])
# Find where we previously streamed up to.
- current_token = self._replication_data_handler.get_streams_to_replicate().get(
- cmd.stream_name
- )
- if current_token is None:
- logger.warning(
- "Got POSITION for stream we're not subscribed to: %s",
- cmd.stream_name,
- )
- return
+ current_token = stream.current_token()
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -333,7 +330,9 @@ class ReplicationCommandHandler:
updates,
current_token,
missing_updates,
- ) = await stream.get_updates_since(current_token, cmd.token)
+ ) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
+ )
# TODO: add some tests for this
@@ -342,7 +341,10 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
- cmd.stream_name, token, [stream.parse_row(row) for row in rows],
+ cmd.stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 617e860f95..41c623d737 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -61,6 +61,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
outbound_redis_connection = None # type: txredisapi.RedisProtocol
def connectionMade(self):
+ super().connectionMade()
logger.info("Connected to redis instance")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())
@@ -119,6 +120,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason):
+ super().connectionLost(reason)
logger.info("Lost connection to redis instance")
self.handler.lost_connection(self)
@@ -189,5 +191,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
+ p.password = self.password
return p
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..b690abedad 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
- # hase been disabled on the master.
+ # has been disabled on the master.
continue
self.streams.append(stream(hs))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4af1afd119..084604e8b0 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
@@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
#
# The arguments are:
#
+# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
@@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
-UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
@@ -93,6 +94,7 @@ class Stream(object):
def __init__(
self,
+ local_instance_name: str,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
@@ -102,15 +104,18 @@ class Stream(object):
implemented by subclasses.
current_token_function is called to get the current token of the underlying
- stream.
+ stream. It is only meaningful on the process that is the source of the
+ replication stream (ie, usually the master).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
Args:
+ local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
+ self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function
@@ -135,14 +140,14 @@ class Stream(object):
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
- self.last_token, current_token
+ self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token
return updates, current_token, limited
async def get_updates_since(
- self, from_token: Token, upto_token: Token
+ self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@@ -160,19 +165,19 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
- from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+ instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited
def db_query_to_update_function(
- query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
- async def update_function(from_token, upto_token, limit):
+ async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
@@ -193,10 +198,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
- from_token: int, upto_token: int, limit: int
+ instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
- stream_name=stream_name, from_token=from_token, upto_token=upto_token,
+ instance_name=instance_name,
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]
@@ -226,6 +234,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -261,7 +270,9 @@ class PresenceStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
- super().__init__(store.get_current_presence_token, update_function)
+ super().__init__(
+ hs.get_instance_name(), store.get_current_presence_token, update_function
+ )
class TypingStream(Stream):
@@ -284,7 +295,9 @@ class TypingStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
- super().__init__(typing_handler.get_current_token, update_function)
+ super().__init__(
+ hs.get_instance_name(), typing_handler.get_current_token, update_function
+ )
class ReceiptsStream(Stream):
@@ -305,6 +318,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -322,14 +336,16 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
- self._current_token, self._update_function
+ hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- async def _update_function(self, from_token: Token, to_token: Token, limit: int):
+ async def _update_function(
+ self, instance_name: str, from_token: Token, to_token: Token, limit: int
+ ):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@@ -356,6 +372,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -387,6 +404,7 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
@@ -412,6 +430,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -432,6 +451,7 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -449,6 +469,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -468,6 +489,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -487,6 +509,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
@@ -517,6 +540,7 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -534,6 +558,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
+ hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 52df81b1bd..890e75d827 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -118,11 +118,17 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
- self._store.get_current_events_token, self._update_function,
+ hs.get_instance_name(),
+ self._store.get_current_events_token,
+ self._update_function,
)
async def _update_function(
- self, from_token: Token, current_token: Token, target_row_count: int
+ self,
+ instance_name: str,
+ from_token: Token,
+ current_token: Token,
+ target_row_count: int,
) -> StreamUpdateResult:
# the events stream merges together three separate sources:
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 75133d7e40..b0505b8a2c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import Stream, make_http_update_function
class FederationStream(Stream):
@@ -35,21 +35,33 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
- )
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: Query master process
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+ # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
- super().__init__(current_token, update_function)
+ super().__init__(hs.get_instance_name(), current_token, update_function)
+
+ @staticmethod
+ def _stub_current_token():
+ # dummy current-token method for use on workers
+ return 0
@staticmethod
- async def _stub_update_function(from_token, upto_token, limit):
+ async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 5736c56032..3bf330da49 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -16,8 +16,6 @@ import logging
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError
@@ -59,8 +57,7 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config)
- @defer.inlineCallbacks
- def maybe_send_server_notice_to_user(self, user_id):
+ async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so
Args:
@@ -78,7 +75,7 @@ class ConsentServerNotices(object):
return
self._users_in_progress.add(user_id)
try:
- u = yield self._store.get_user_by_id(user_id)
+ u = await self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests:
# don't send to guests
@@ -100,8 +97,8 @@ class ConsentServerNotices(object):
content = copy_with_str_subst(
self._server_notice_content, {"consent_uri": consent_uri}
)
- yield self._server_notices_manager.send_notice(user_id, content)
- yield self._store.user_set_consent_server_notice_sent(
+ await self._server_notices_manager.send_notice(user_id, content)
+ await self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version
)
except SynapseError as e:
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index ce4a828894..d97166351e 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -16,8 +16,6 @@ import logging
from six import iteritems
-from twisted.internet import defer
-
from synapse.api.constants import (
EventTypes,
LimitBlockingTypes,
@@ -50,8 +48,7 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def maybe_send_server_notice_to_user(self, user_id):
+ async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in
two cases.
1. The server has reached its limit does not reflect this
@@ -74,13 +71,13 @@ class ResourceLimitsServerNotices(object):
# Don't try and send server notices unless they've been enabled
return
- timestamp = yield self._store.user_last_seen_monthly_active(user_id)
+ timestamp = await self._store.user_last_seen_monthly_active(user_id)
if timestamp is None:
# This user will be blocked from receiving the notice anyway.
# In practice, not sure we can ever get here
return
- room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user(
+ room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
user_id
)
@@ -88,10 +85,10 @@ class ResourceLimitsServerNotices(object):
logger.warning("Failed to get server notices room")
return
- yield self._check_and_set_tags(user_id, room_id)
+ await self._check_and_set_tags(user_id, room_id)
# Determine current state of room
- currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
+ currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
limit_msg = None
limit_type = None
@@ -99,7 +96,7 @@ class ResourceLimitsServerNotices(object):
# Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen
# to other users if they were to arrive.
- yield self._auth.check_auth_blocking()
+ await self._auth.check_auth_blocking()
except ResourceLimitError as e:
limit_msg = e.msg
limit_type = e.limit_type
@@ -112,22 +109,21 @@ class ResourceLimitsServerNotices(object):
# We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return
if currently_blocked:
- self._remove_limit_block_notification(user_id, ref_events)
+ await self._remove_limit_block_notification(user_id, ref_events)
return
if currently_blocked and not limit_msg:
# Room is notifying of a block, when it ought not to be.
- yield self._remove_limit_block_notification(user_id, ref_events)
+ await self._remove_limit_block_notification(user_id, ref_events)
elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be.
- yield self._apply_limit_block_notification(
+ await self._apply_limit_block_notification(
user_id, limit_msg, limit_type
)
except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e)
- @defer.inlineCallbacks
- def _remove_limit_block_notification(self, user_id, ref_events):
+ async def _remove_limit_block_notification(self, user_id, ref_events):
"""Utility method to remove limit block notifications from the server
notices room.
@@ -137,12 +133,13 @@ class ResourceLimitsServerNotices(object):
limit blocking and need to be preserved.
"""
content = {"pinned": ref_events}
- yield self._server_notices_manager.send_notice(
+ await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
- @defer.inlineCallbacks
- def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
+ async def _apply_limit_block_notification(
+ self, user_id, event_body, event_limit_type
+ ):
"""Utility method to apply limit block notifications in the server
notices room.
@@ -159,17 +156,16 @@ class ResourceLimitsServerNotices(object):
"admin_contact": self._config.admin_contact,
"limit_type": event_limit_type,
}
- event = yield self._server_notices_manager.send_notice(
+ event = await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Message
)
content = {"pinned": [event.event_id]}
- yield self._server_notices_manager.send_notice(
+ await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
- @defer.inlineCallbacks
- def _check_and_set_tags(self, user_id, room_id):
+ async def _check_and_set_tags(self, user_id, room_id):
"""
Since server notices rooms were originally not with tags,
important to check that tags have been set correctly
@@ -177,20 +173,19 @@ class ResourceLimitsServerNotices(object):
user_id(str): the user in question
room_id(str): the server notices room for that user
"""
- tags = yield self._store.get_tags_for_room(user_id, room_id)
+ tags = await self._store.get_tags_for_room(user_id, room_id)
need_to_set_tag = True
if tags:
if SERVER_NOTICE_ROOM_TAG in tags:
# tag already present, nothing to do here
need_to_set_tag = False
if need_to_set_tag:
- max_id = yield self._store.add_tag_to_room(
+ max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
- @defer.inlineCallbacks
- def _is_room_currently_blocked(self, room_id):
+ async def _is_room_currently_blocked(self, room_id):
"""
Determines if the room is currently blocked
@@ -198,7 +193,7 @@ class ResourceLimitsServerNotices(object):
room_id(str): The room id of the server notices room
Returns:
-
+ Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking
This list can be used as a convenience in the case where the block
@@ -208,7 +203,7 @@ class ResourceLimitsServerNotices(object):
currently_blocked = False
pinned_state_event = None
try:
- pinned_state_event = yield self._state.get_current_state(
+ pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned
)
except AuthError:
@@ -219,7 +214,7 @@ class ResourceLimitsServerNotices(object):
if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", []))
- events = yield self._store.get_events(referenced_events)
+ events = await self._store.get_events(referenced_events)
for event_id, event in iteritems(events):
if event.type != EventTypes.Message:
continue
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index bf0943f265..999c621b92 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -14,11 +14,9 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.types import UserID, create_requester
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -51,8 +49,7 @@ class ServerNoticesManager(object):
"""
return self._config.server_notices_mxid is not None
- @defer.inlineCallbacks
- def send_notice(
+ async def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None
):
"""Send a notice to the given user
@@ -68,8 +65,8 @@ class ServerNoticesManager(object):
Returns:
Deferred[FrozenEvent]
"""
- room_id = yield self.get_or_create_notice_room_for_user(user_id)
- yield self.maybe_invite_user_to_room(user_id, room_id)
+ room_id = await self.get_or_create_notice_room_for_user(user_id)
+ await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid)
@@ -86,13 +83,13 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
- res = yield self._event_creation_handler.create_and_send_nonmember_event(
+ res = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
return res
- @cachedInlineCallbacks()
- def get_or_create_notice_room_for_user(self, user_id):
+ @cached()
+ async def get_or_create_notice_room_for_user(self, user_id):
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
@@ -109,7 +106,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
- rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
+ rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in rooms:
@@ -118,7 +115,7 @@ class ServerNoticesManager(object):
# be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it
# the server notices room.
- user_ids = yield self._store.get_users_in_room(room.room_id)
+ user_ids = await self._store.get_users_in_room(room.room_id)
if self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice
# user
@@ -146,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
- info = yield self._room_creation_handler.create_room(
+ info = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
@@ -158,7 +155,7 @@ class ServerNoticesManager(object):
)
room_id = info["room_id"]
- max_id = yield self._store.add_tag_to_room(
+ max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@@ -166,8 +163,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id
- @defer.inlineCallbacks
- def maybe_invite_user_to_room(self, user_id: str, room_id: str):
+ async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
"""Invite the given user to the given server room, unless the user has already
joined or been invited to it.
@@ -179,14 +175,14 @@ class ServerNoticesManager(object):
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
- joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
+ joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in joined_rooms:
if room.room_id == room_id:
return
- yield self._room_member_handler.update_membership(
+ await self._room_member_handler.update_membership(
requester=requester,
target=UserID.from_string(user_id),
room_id=room_id,
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index 652bab58e3..be74e86641 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
@@ -36,18 +34,16 @@ class ServerNoticesSender(object):
ResourceLimitsServerNotices(hs),
)
- @defer.inlineCallbacks
- def on_user_syncing(self, user_id):
+ async def on_user_syncing(self, user_id):
"""Called when the user performs a sync operation.
Args:
user_id (str): mxid of user who synced
"""
for sn in self._server_notices:
- yield sn.maybe_send_server_notice_to_user(user_id)
+ await sn.maybe_send_server_notice_to_user(user_id)
- @defer.inlineCallbacks
- def on_user_ip(self, user_id):
+ async def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request.
Args:
@@ -57,4 +53,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing
for sn in self._server_notices:
- yield sn.maybe_send_server_notice_to_user(user_id)
+ await sn.maybe_send_server_notice_to_user(user_id)
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 3e53c8568a..efcdd2100b 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
- @defer.inlineCallbacks
- def is_server_admin(self, user):
+ async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
Args:
@@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
index 163529c071..bbdde121e8 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
@@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN (
'populate_stats_cleanup'
);
+-- this relies on current_state_events.membership having been populated, so add
+-- a dependency on current_state_events_membership.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('populate_stats_process_rooms', '{}', '');
+ ('populate_stats_process_rooms', '{}', 'current_state_events_membership');
+-- this also relies on current_state_events.membership having been populated, but
+-- we get that as a side-effect of depending on populate_stats_process_rooms.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+ """An ID generator that tracks a stream that can have multiple writers.
+
+ Uses a Postgres sequence to coordinate ID assignment, but positions of other
+ writers will only get updated when `advance` is called (by replication).
+
+ Note: Only works with Postgres.
+
+ Args:
+ db_conn
+ db
+ instance_name: The name of this instance.
+ table: Database table associated with stream.
+ instance_column: Column that stores the row's writer's instance name
+ id_column: Column that stores the stream ID.
+ sequence_name: The name of the postgres sequence used to generate new
+ IDs.
+ """
+
+ def __init__(
+ self,
+ db_conn,
+ db: Database,
+ instance_name: str,
+ table: str,
+ instance_column: str,
+ id_column: str,
+ sequence_name: str,
+ ):
+ self._db = db
+ self._instance_name = instance_name
+ self._sequence_name = sequence_name
+
+ # We lock as some functions may be called from DB threads.
+ self._lock = threading.Lock()
+
+ self._current_positions = self._load_current_ids(
+ db_conn, table, instance_column, id_column
+ )
+
+ # Set of local IDs that we're still processing. The current position
+ # should be less than the minimum of this set (if not empty).
+ self._unfinished_ids = set() # type: Set[int]
+
+ def _load_current_ids(
+ self, db_conn, table: str, instance_column: str, id_column: str
+ ) -> Dict[str, int]:
+ sql = """
+ SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ GROUP BY %(instance)s
+ """ % {
+ "instance": instance_column,
+ "id": id_column,
+ "table": table,
+ }
+
+ cur = db_conn.cursor()
+ cur.execute(sql)
+
+ # `cur` is an iterable over returned rows, which are 2-tuples.
+ current_positions = dict(cur)
+
+ cur.close()
+
+ return current_positions
+
+ def _load_next_id_txn(self, txn):
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ (next_id,) = txn.fetchone()
+ return next_id
+
+ async def get_next(self):
+ """
+ Usage:
+ with await stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+ # Assert the fetched ID is actually greater than what we currently
+ # believe the ID to be. If not, then the sequence and table have got
+ # out of sync somehow.
+ assert self.get_current_token() < next_id
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ self._mark_id_as_finished(next_id)
+
+ return manager()
+
+ def get_next_txn(self, txn: LoggingTransaction):
+ """
+ Usage:
+
+ stream_id = stream_id_gen.get_next(txn)
+ # ... persist event ...
+ """
+
+ next_id = self._load_next_id_txn(txn)
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ txn.call_after(self._mark_id_as_finished, next_id)
+ txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+ return next_id
+
+ def _mark_id_as_finished(self, next_id: int):
+ """The ID has finished being processed so we should advance the
+ current poistion if possible.
+ """
+
+ with self._lock:
+ self._unfinished_ids.discard(next_id)
+
+ # Figure out if its safe to advance the position by checking there
+ # aren't any lower allocated IDs that are yet to finish.
+ if all(c > next_id for c in self._unfinished_ids):
+ curr = self._current_positions.get(self._instance_name, 0)
+ self._current_positions[self._instance_name] = max(curr, next_id)
+
+ def get_current_token(self, instance_name: str = None) -> int:
+ """Gets the current position of a named writer (defaults to current
+ instance).
+
+ Returns 0 if we don't have a position for the named writer (likely due
+ to it being a new writer).
+ """
+
+ if instance_name is None:
+ instance_name = self._instance_name
+
+ with self._lock:
+ return self._current_positions.get(instance_name, 0)
+
+ def get_positions(self) -> Dict[str, int]:
+ """Get a copy of the current positon map.
+ """
+
+ with self._lock:
+ return dict(self._current_positions)
+
+ def advance(self, instance_name: str, new_id: int):
+ """Advance the postion of the named writer to the given ID, if greater
+ than existing entry.
+ """
+
+ with self._lock:
+ self._current_positions[instance_name] = max(
+ new_id, self._current_positions.get(instance_name, 0)
+ )
|