diff --git a/synapse/__init__.py b/synapse/__init__.py
index 0116478fbb..5bb09a37d7 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.16.1"
+__version__ = "1.17.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 06ba6604f3..40dc62ef6c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Optional
@@ -22,7 +21,6 @@ from netaddr import IPAddress
from twisted.internet import defer
from twisted.web.server import Request
-import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
@@ -35,6 +33,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.logging import opentracing
from synapse.types import StateMap, UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
@@ -538,7 +537,7 @@ class Auth(object):
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
# to the event's `auth_events` (e.g. joins pointing to previous joins
- # when room is publically joinable). Dropping event IDs has the
+ # when room is publicly joinable). Dropping event IDs has the
# advantage that the auth chain for the room grows slower, but we use
# the auth chain in state resolution v2 to order events, which means
# care must be taken if dropping events to ensure that it doesn't
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 27a3fc9ed6..f6792d9fc8 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
-from twisted.internet import defer, reactor
+from twisted.internet import address, defer, reactor
import synapse
import synapse.events
@@ -206,10 +206,30 @@ class KeyUploadServlet(RestServlet):
if body:
# They're actually trying to upload something, proxy to main synapse.
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
- headers = {"Authorization": auth_headers}
+
+ # Proxy headers from the original request, such as the auth headers
+ # (in case the access token is there) and the original IP /
+ # User-Agent of the request.
+ headers = {
+ header: request.requestHeaders.getRawHeaders(header, [])
+ for header in (b"Authorization", b"User-Agent")
+ }
+ # Add the previous hop to the X-Forwarded-For header.
+ x_forwarded_for = request.requestHeaders.getRawHeaders(
+ b"X-Forwarded-For", []
+ )
+ if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
+ previous_host = request.client.host.encode("ascii")
+ # If the header exists, append to the comma-separated list in the
+ # first instance of the header. Otherwise, generate a new header.
+ if x_forwarded_for:
+ x_forwarded_for = [
+ x_forwarded_for[0] + b", " + previous_host
+ ] + x_forwarded_for[1:]
+ else:
+ x_forwarded_for = [previous_host]
+ headers[b"X-Forwarded-For"] = x_forwarded_for
+
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
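
The append-or-create logic for X-Forwarded-For above is self-contained enough to exercise in isolation. A minimal sketch of the same rules (the `append_forwarded_for` helper is hypothetical, not part of Synapse):

    from typing import Dict, List

    def append_forwarded_for(
        headers: Dict[bytes, List[bytes]], previous_host: bytes
    ) -> None:
        """Append the previous hop to the first X-Forwarded-For header,
        or create the header if it is absent."""
        x_forwarded_for = headers.get(b"X-Forwarded-For", [])
        if x_forwarded_for:
            x_forwarded_for = [
                x_forwarded_for[0] + b", " + previous_host
            ] + x_forwarded_for[1:]
        else:
            x_forwarded_for = [previous_host]
        headers[b"X-Forwarded-For"] = x_forwarded_for

    headers = {b"X-Forwarded-For": [b"10.0.0.1"]}
    append_forwarded_for(headers, b"192.168.1.5")
    assert headers[b"X-Forwarded-For"] == [b"10.0.0.1, 192.168.1.5"]
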
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index da9a5e86d4..f92bfb420b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient):
if service.url is None:
return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
- response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index fca35b008c..65043d5b5b 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -16,6 +16,7 @@ from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
+
from synapse.config.homeserver import HomeServerConfig
action = sys.argv[1]
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index ca61214454..b1dc7ad502 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
@@ -73,7 +72,7 @@ class EmailConfig(Config):
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
- # we don't yet know what auxilliary templates like mail.css we will need).
+ # we don't yet know what auxiliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
@@ -145,8 +144,8 @@ class EmailConfig(Config):
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
# make sure we can import the required deps
- import jinja2
import bleach
+ import jinja2
# prevent unused warnings
jinja2
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index a568726985..fce96b4acf 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -45,10 +45,37 @@ class JWTConfig(Config):
def generate_config_section(self, **kwargs):
return """\
- # The JWT needs to contain a globally unique "sub" (subject) claim.
+ # JSON web token integration. The following settings can be used to make
+ # Synapse use JSON web tokens for authentication, instead of its internal
+ # password database.
+ #
+ # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
+ # used as the localpart of the mxid.
+ #
+ # Note that this is a non-standard login type and client support is
+ # expected to be non-existent.
+ #
+ # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
#
#jwt_config:
- # enabled: true
- # secret: "a secret"
- # algorithm: "HS256"
+ # Uncomment the following to enable authentication using JSON web
+ # tokens. Defaults to false.
+ #
+ #enabled: true
+
+ # This is either the private shared secret or the public key used to
+ # decode the contents of the JSON web token.
+ #
+ # Required if 'enabled' is true.
+ #
+ #secret: "provided-by-your-issuer"
+
+ # The algorithm used to sign the JSON web token.
+ #
+ # Supported algorithms are listed at
+ # https://pyjwt.readthedocs.io/en/latest/algorithms.html
+ #
+ # Required if 'enabled' is true.
+ #
+ #algorithm: "provided-by-your-issuer"
"""
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index a0c4a40c27..92aadfe7ef 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -162,7 +162,7 @@ class EventBuilderFactory(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.hostname = hs.hostname
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
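
This diff repeatedly replaces `hs.config.signing_key[0]` with `hs.signing_key`. A sketch of the convenience attribute being assumed here (the shape is illustrative; the real HomeServer class is far larger):

    class HomeServer:
        def __init__(self, hostname: str, config) -> None:
            self.hostname = hostname
            self.config = config
            # config.signing_key is a list of signing keys; the first entry
            # is the one used to sign outgoing events and federation requests.
            self.signing_key = config.signing_key[0]
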
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 687cd841ac..a37cc9cb4a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -87,7 +87,7 @@ class FederationClient(FederationBase):
self.transport_layer = hs.get_federation_transport_client()
self.hostname = hs.hostname
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -245,7 +245,7 @@ class FederationClient(FederationBase):
event_id: event to fetch
room_version: version of the room
outlier: Indicates whether the PDU is an `outlier`, i.e. if
- it's from an arbitary point in the context as opposed to part
+ it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@@ -351,7 +351,7 @@ class FederationClient(FederationBase):
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
- """Takes a list of PDUs and checks the signatures and hashs of each
+ """Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request it from the originating server of
that PDU.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e704cf2f44..86051decd4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -717,7 +717,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warning("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -731,7 +731,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warning("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -741,7 +741,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warning("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 6bbd762681..860b03f7b9 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -359,7 +359,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
@staticmethod
def from_data(data):
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 4e698981a4..12966e239b 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -119,7 +119,7 @@ class PerDestinationQueue(object):
)
def send_pdu(self, pdu: EventBase, order: int) -> None:
- """Add a PDU to the queue, and start the transmission loop if neccessary
+ """Add a PDU to the queue, and start the transmission loop if necessary
Args:
pdu: pdu to send
@@ -129,7 +129,7 @@ class PerDestinationQueue(object):
self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
- """Add presence updates to the queue. Start the transmission loop if neccessary.
+ """Add presence updates to the queue. Start the transmission loop if necessary.
Args:
states: presence to send
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9f99311419..cfdf23d366 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -746,7 +746,7 @@ class TransportLayerClient(object):
def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content
):
- """Remove a user fron a group
+ """Remove a user from a group
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af4595498c..d1bac318e7 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -109,7 +109,7 @@ class Authenticator(object):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
- self.notifer = hs.get_notifier()
+ self.notifier = hs.get_notifier()
self.replication_client = None
if hs.config.worker.worker_app:
@@ -175,7 +175,7 @@ class Authenticator(object):
await self.store.set_destination_retry_timings(origin, None, 0, 0)
# Inform the relevant places that the remote server is back up.
- self.notifer.notify_remote_server_up(origin)
+ self.notifier.notify_remote_server_up(origin)
if self.replication_client:
# If we're on a worker we try and inform master about this. The
# replication client doesn't hook into the notifier to avoid
@@ -361,11 +361,7 @@ class BaseFederationServlet(object):
continue
server.register_paths(
- method,
- (pattern,),
- self._wrap(code),
- self.__class__.__name__,
- trace=False,
+ method, (pattern,), self._wrap(code), self.__class__.__name__,
)
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 27b0c02655..dab13c243f 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -70,7 +70,7 @@ class GroupAttestationSigning(object):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
@defer.inlineCallbacks
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 8db8ab1b7b..8cb922ddc7 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -41,7 +41,7 @@ class GroupsServerWorkerHandler(object):
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.server_name = hs.hostname
self.attestations = hs.get_groups_attestation_signing()
self.transport_client = hs.get_federation_transport_client()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 904c96eeec..92d4c6e16c 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -48,8 +48,7 @@ class ApplicationServicesHandler(object):
self.current_max = 0
self.is_processing = False
- @defer.inlineCallbacks
- def notify_interested_services(self, current_id):
+ async def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@@ -74,7 +73,7 @@ class ApplicationServicesHandler(object):
(
upper_bound,
events,
- ) = yield self.store.get_new_events_for_appservice(
+ ) = await self.store.get_new_events_for_appservice(
self.current_max, limit
)
@@ -85,10 +84,9 @@ class ApplicationServicesHandler(object):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- @defer.inlineCallbacks
- def handle_event(event):
+ async def handle_event(event):
# Gather interested services
- services = yield self._get_services_for_event(event)
+ services = await self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
@@ -96,9 +94,9 @@ class ApplicationServicesHandler(object):
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
- yield self._check_user_exists(event.sender)
+ await self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
- yield self._check_user_exists(event.state_key)
+ await self._check_user_exists(event.state_key)
if not self.started_scheduler:
@@ -115,17 +113,16 @@ class ApplicationServicesHandler(object):
self.scheduler.submit_event_for_as(service, event)
now = self.clock.time_msec()
- ts = yield self.store.get_received_ts(event.event_id)
+ ts = await self.store.get_received_ts(event.event_id)
synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender"
).observe((now - ts) / 1000)
- @defer.inlineCallbacks
- def handle_room_events(events):
+ async def handle_room_events(events):
for event in events:
- yield handle_event(event)
+ await handle_event(event)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
@@ -135,10 +132,10 @@ class ApplicationServicesHandler(object):
)
)
- yield self.store.set_appservice_last_pos(upper_bound)
+ await self.store.set_appservice_last_pos(upper_bound)
now = self.clock.time_msec()
- ts = yield self.store.get_received_ts(events[-1].event_id)
+ ts = await self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
"appservice_sender"
@@ -161,8 +158,7 @@ class ApplicationServicesHandler(object):
finally:
self.is_processing = False
- @defer.inlineCallbacks
- def query_user_exists(self, user_id):
+ async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
Args:
@@ -170,15 +166,14 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
- user_query_services = yield self._get_services_for_user(user_id=user_id)
+ user_query_services = self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
- is_known_user = yield self.appservice_api.query_user(user_service, user_id)
+ is_known_user = await self.appservice_api.query_user(user_service, user_id)
if is_known_user:
return True
return False
- @defer.inlineCallbacks
- def query_room_alias_exists(self, room_alias):
+ async def query_room_alias_exists(self, room_alias):
"""Check if an application service knows this room alias exists.
Args:
@@ -193,19 +188,18 @@ class ApplicationServicesHandler(object):
s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
- is_known_alias = yield self.appservice_api.query_alias(
+ is_known_alias = await self.appservice_api.query_alias(
alias_service, room_alias_str
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = yield self.store.get_association_from_room_alias(room_alias)
+ result = await self.store.get_association_from_room_alias(room_alias)
return result
- @defer.inlineCallbacks
- def query_3pe(self, kind, protocol, fields):
- services = yield self._get_services_for_3pn(protocol)
+ async def query_3pe(self, kind, protocol, fields):
+ services = self._get_services_for_3pn(protocol)
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
@@ -224,8 +218,7 @@ class ApplicationServicesHandler(object):
return ret
- @defer.inlineCallbacks
- def get_3pe_protocols(self, only_protocol=None):
+ async def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services()
protocols = {}
@@ -238,7 +231,7 @@ class ApplicationServicesHandler(object):
if p not in protocols:
protocols[p] = []
- info = yield self.appservice_api.get_3pe_protocol(s, p)
+ info = await self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
@@ -263,8 +256,7 @@ class ApplicationServicesHandler(object):
return protocols
- @defer.inlineCallbacks
- def _get_services_for_event(self, event):
+ async def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
Args:
@@ -280,7 +272,7 @@ class ApplicationServicesHandler(object):
# inside of a list comprehension anymore.
interested_list = []
for s in services:
- if (yield s.is_interested(event, self.store)):
+ if await s.is_interested(event, self.store):
interested_list.append(s)
return interested_list
@@ -288,21 +280,20 @@ class ApplicationServicesHandler(object):
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
- return defer.succeed(interested_list)
+ return interested_list
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
- return defer.succeed(interested_list)
+ return interested_list
- @defer.inlineCallbacks
- def _is_unknown_user(self, user_id):
+ async def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
return False
- user_info = yield self.store.get_user_by_id(user_id)
+ user_info = await self.store.get_user_by_id(user_id)
if user_info:
return False
@@ -311,10 +302,9 @@ class ApplicationServicesHandler(object):
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
- @defer.inlineCallbacks
- def _check_user_exists(self, user_id):
- unknown_user = yield self._is_unknown_user(user_id)
+ async def _check_user_exists(self, user_id):
+ unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
- exists = yield self.query_user_exists(user_id)
+ exists = await self.query_user_exists(user_id)
return exists
return True
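
The hunks above are a mechanical conversion from Twisted's generator-based coroutines to native ones. The pattern, in miniature (the `store` argument is illustrative):

    from twisted.internet import defer

    # Before: a generator decorated with inlineCallbacks; results arrive via yield.
    @defer.inlineCallbacks
    def get_thing_old(store):
        thing = yield store.get_thing()
        return thing

    # After: a native coroutine; callers await it, or wrap it with
    # defer.ensureDeferred() where a Deferred is required.
    async def get_thing_new(store):
        return await store.get_thing()

Note that `_get_services_for_user` and `_get_services_for_3pn` now return plain lists rather than `defer.succeed(...)`, so their callers drop the `yield`/`await` entirely.
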
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c3f86e7414..a162392e4c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import time
import unicodedata
@@ -24,7 +23,6 @@ import attr
import bcrypt # type: ignore[import]
import pymacaroons
-import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -45,6 +43,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
+from synapse.util import stringutils
+from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -928,7 +928,7 @@ class AuthHandler(BaseHandler):
# for the presence of an email address during password reset was
# case sensitive).
if medium == "email":
- address = address.lower()
+ address = canonicalise_email(address)
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
@@ -956,7 +956,7 @@ class AuthHandler(BaseHandler):
# 'Canonicalise' email addresses as per above
if medium == "email":
- address = address.lower()
+ address = canonicalise_email(address)
identity_handler = self.hs.get_handlers().identity_handler
result = await identity_handler.try_unbind_threepid(
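
A hedged sketch of what `canonicalise_email` is assumed to do here, based on the surrounding usage: lower-casing alone was the previous behaviour, and the real helper in synapse.util.threepids also validates the address (and raises a SynapseError rather than ValueError):

    def canonicalise_email(address: str) -> str:
        """Case-fold the local part and lower-case the domain, rejecting
        anything that does not look like local@domain."""
        address = address.strip()
        parts = address.split("@")
        if len(parts) != 2:
            raise ValueError("Unable to parse email address %r" % (address,))
        return parts[0].casefold() + "@" + parts[1].lower()

    assert canonicalise_email(" Alice@Example.COM ") == "alice@example.com"
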
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 76f213723a..d79ffefdb5 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import urllib
-import xml.etree.ElementTree as ET
from typing import Dict, Optional, Tuple
+from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4dbd8e1d98..ca7da42a3f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,8 +19,9 @@
import itertools
import logging
+from collections.abc import Container
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes
@@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
# device and recognize the algorithm then we can work out the
# exact key to expect. Otherwise check it matches any key we
# have for that device.
+
+ current_keys = [] # type: Container[str]
+
if device:
keys = device.get("keys", {}).get("keys", {})
@@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
current_keys = keys.values()
elif device_id:
# We don't have any keys for the device ID.
- current_keys = []
+ pass
else:
# The event didn't include a device ID, so we just look for
# keys across all devices.
- current_keys = (
+ current_keys = [
key
for device in cached_devices
for key in device.get("keys", {}).get("keys", {}).values()
- )
+ ]
# We now check that the sender key matches (one of) the expected
# keys.
@@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
- joined_domains = {}
+ joined_domains = {} # type: Dict[str, int]
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
try:
# Try the host we successfully got a response to /make_join/
# request first.
+ host_list = list(target_hosts)
try:
- target_hosts.remove(origin)
- target_hosts.insert(0, origin)
+ host_list.remove(origin)
+ host_list.insert(0, origin)
except ValueError:
pass
ret = await self.federation_client.send_join(
- target_hosts, event, room_version_obj
+ host_list, event, room_version_obj
)
origin = ret["origin"]
@@ -1562,7 +1567,7 @@ class FederationHandler(BaseHandler):
room_version,
event.get_pdu_json(),
self.hs.hostname,
- self.hs.config.signing_key[0],
+ self.hs.signing_key,
)
)
@@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
# Try the host that we successfully called /make_leave/ on first for
# the /send_leave/ request.
+ host_list = list(target_hosts)
try:
- target_hosts.remove(origin)
- target_hosts.insert(0, origin)
+ host_list.remove(origin)
+ host_list.insert(0, origin)
except ValueError:
pass
- await self.federation_client.send_leave(target_hosts, event)
+ await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
@@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
user_id: str,
membership: str,
content: JsonDict = {},
- params: Optional[Dict[str, str]] = None,
+ params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]:
(
origin,
@@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler):
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
- if do_soft_fail_check:
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-
- extrem_ids = set(extrem_ids)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- do_soft_fail_check = False
-
- if do_soft_fail_check:
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
- state_sets = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets = list(state_sets.values())
- state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids = {
- k: e.event_id for k, e in current_state_ids.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
+ extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids)
+ prev_event_ids = set(event.prev_event_ids())
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets = list(state_sets.values())
+ state_sets.append(state)
+ current_state_ids = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
)
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(event)
- current_state_ids = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
+ logger.debug(
+ "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+ )
- current_auth_events = await self.store.get_events(current_state_ids)
- current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
- }
+ # Now check if the event passes auth against said current state
+ auth_types = auth_types_for_event(event)
+ current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
- try:
- event_auth.check(
- room_version_obj, event, auth_events=current_auth_events
- )
- except AuthError as e:
- logger.warning("Soft-failing %r because %s", event, e)
- event.internal_metadata.soft_failed = True
+ current_auth_events = await self.store.get_events(current_state_ids)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in current_auth_events.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning("Soft-failing %r because %s", event, e)
+ event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
@@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler):
remote_auth_chain = await self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
- except RequestSendFailed as e:
+ except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e)
+ logger.info("Failed to get event auth from remote: %s", e1)
return context
seen_remotes = await self.store.have_seen_events(
@@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
- last_exception = None
+ last_exception = None # type: Optional[Exception]
+
# for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
@@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler):
return
except Exception as e:
last_exception = e
+
+ if last_exception is None:
+ # we can only get here if get_public_keys() returned an empty list
+ # TODO: make this better
+ raise RuntimeError("no public key in invite event")
+
raise last_exception
async def _check_key_revocation(self, public_key, url):
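
The `host_list = list(target_hosts)` copies in the hunks above exist because `target_hosts` may be an immutable sequence (a tuple, for instance), which cannot be reordered in place. The reordering itself, extracted into a hypothetical helper:

    from typing import Iterable, List

    def prioritise_origin(hosts: Iterable[str], origin: str) -> List[str]:
        """Return hosts as a list with `origin` moved to the front, if present."""
        host_list = list(hosts)
        try:
            host_list.remove(origin)
            host_list.insert(0, origin)
        except ValueError:
            pass  # origin not in the list; leave the order unchanged
        return host_list

    assert prioritise_origin(("a", "b", "c"), "b") == ["b", "a", "c"]
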
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7cb106e365..ecdb12a7bf 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -70,7 +70,7 @@ class GroupsLocalWorkerHandler(object):
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 4ba0042768..701233ebb4 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -251,10 +251,10 @@ class IdentityHandler(BaseHandler):
# 'browser-like' HTTPS.
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
- method="POST",
+ method=b"POST",
url_bytes=url_bytes,
content=content,
- destination_is=id_server,
+ destination_is=id_server.encode("ascii"),
)
headers = {b"Authorization": auth_headers}
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 665ad19b5d..da206e1ec1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from canonicaljson import encode_canonical_json, json
@@ -55,6 +55,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -349,7 +352,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
class EventCreationHandler(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -814,11 +817,17 @@ class EventCreationHandler(object):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- try:
- await self.auth.check_from_context(room_version, event, context)
- except AuthError as err:
- logger.warning("Denying new event %r because %s", event, err)
- raise err
+ if event.internal_metadata.is_out_of_band_membership():
+ # the only sort of out-of-band-membership events we expect to see here
+ # are invite rejections we have generated ourselves.
+ assert event.type == EventTypes.Member
+ assert event.content["membership"] == Membership.LEAVE
+ else:
+ try:
+ await self.auth.check_from_context(room_version, event, context)
+ except AuthError as err:
+ logger.warning("Denying new event %r because %s", event, err)
+ raise err
# Ensure that we can round trip before trying to persist in db
try:
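
The `TYPE_CHECKING` guard added above is the standard way to annotate against a module that cannot be imported at runtime without creating an import cycle. In miniature:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by type checkers, so no circular import at runtime.
        from synapse.server import HomeServer

    class EventCreationHandler(object):
        def __init__(self, hs: "HomeServer"):  # string form: a forward reference
            self.hs = hs
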
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 27c479da9e..a1a8fa1d3b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,17 +16,21 @@
import abc
import logging
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+from unpaddedbase64 import encode_base64
from synapse import types
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.room_versions import EventFormatVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
+from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
-from synapse.replication.http.membership import (
- ReplicationLocallyRejectInviteRestServlet,
-)
-from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
+from synapse.events.validator import EventValidator
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -74,10 +76,6 @@ class RoomMemberHandler(object):
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
- else:
- self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
- hs
- )
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
@@ -105,46 +103,28 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- async def _remote_reject_invite(
+ async def remote_reject_invite(
self,
+ invite_event_id: str,
+ txn_id: Optional[str],
requester: Requester,
- remote_room_hosts: List[str],
- room_id: str,
- target: UserID,
- content: dict,
- ) -> Tuple[Optional[str], int]:
- """Attempt to reject an invite for a room this server is not in. If we
- fail to do so we locally mark the invite as rejected.
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rejects an out-of-band invite we have received from a remote server
Args:
- requester
- remote_room_hosts: List of servers to use to try and reject invite
- room_id
- target: The user rejecting the invite
- content: The content for the rejection event
+ invite_event_id: ID of the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
Returns:
- A dictionary to be returned to the client, may
- include event_id etc, or nothing if we locally rejected
+ event id, stream_id of the leave event
"""
raise NotImplementedError()
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
- if self._is_on_event_persistence_instance:
- return await self.persist_event_storage.locally_reject_invite(
- user_id, room_id
- )
- else:
- result = await self._locally_reject_client(
- instance_name=self._event_stream_writer_instance,
- user_id=user_id,
- room_id=room_id,
- )
- return result["stream_id"]
-
@abc.abstractmethod
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
@@ -288,7 +268,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
- ) -> Tuple[Optional[str], int]:
+ ) -> Tuple[str, int]:
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -319,7 +299,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
- ) -> Tuple[Optional[str], int]:
+ ) -> Tuple[str, int]:
content_specified = bool(content)
if content is None:
content = {}
@@ -485,11 +465,17 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = await self._get_inviter(target.to_string(), room_id)
- if not inviter:
+ invite = await self.store.get_invite_for_local_user_in_room(
+ user_id=target.to_string(), room_id=room_id
+ ) # type: Optional[RoomsForUser]
+ if not invite:
raise SynapseError(404, "Not a known room")
- if self.hs.is_mine(inviter):
+ logger.info(
+ "%s rejects invite to %s from %s", target, room_id, invite.sender
+ )
+
+ if self.hs.is_mine_id(invite.sender):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
@@ -497,10 +483,10 @@ class RoomMemberHandler(object):
# active on other servers.
pass
else:
- # send the rejection to the inviter's HS.
- remote_room_hosts = remote_room_hosts + [inviter.domain]
- return await self._remote_reject_invite(
- requester, remote_room_hosts, room_id, target, content,
+ # send the rejection to the inviter's HS (with fallback to
+ # local event)
+ return await self.remote_reject_invite(
+ invite.event_id, txn_id, requester, content,
)
return await self._local_membership_update(
@@ -1014,33 +1000,119 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
- async def _remote_reject_invite(
+ async def remote_reject_invite(
self,
+ invite_event_id: str,
+ txn_id: Optional[str],
requester: Requester,
- remote_room_hosts: List[str],
- room_id: str,
- target: UserID,
- content: dict,
- ) -> Tuple[Optional[str], int]:
- """Implements RoomMemberHandler._remote_reject_invite
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rejects an out-of-band invite received from a remote user
+
+ Implements RoomMemberHandler.remote_reject_invite
"""
+ invite_event = await self.store.get_event(invite_event_id)
+ room_id = invite_event.room_id
+ target_user = invite_event.state_key
+
+ # first of all, try doing a rejection via the inviting server
fed_handler = self.federation_handler
try:
+ inviter_id = UserID.from_string(invite_event.sender)
event, stream_id = await fed_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, target.to_string(), content=content,
+ [inviter_id.domain], room_id, target_user, content=content
)
return event.event_id, stream_id
except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
+ # if we were unable to reject the invite, we will generate our own
+ # leave event.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
- stream_id = await self.locally_reject_invite(target.to_string(), room_id)
- return None, stream_id
+ return await self._locally_reject_invite(
+ invite_event, txn_id, requester, content
+ )
+
+ async def _locally_reject_invite(
+ self,
+ invite_event: EventBase,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Generate a local invite rejection
+
+ This is called after we fail to reject an invite via a remote server. It
+ generates an out-of-band membership event locally.
+
+ Args:
+ invite_event: the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
+ """
+
+ room_id = invite_event.room_id
+ target_user = invite_event.state_key
+ room_version = await self.store.get_room_version(room_id)
+
+ content["membership"] = Membership.LEAVE
+
+ # the auth events for the new event are the same as that of the invite, plus
+ # the invite itself.
+ #
+ # the prev_events are just the invite.
+ invite_hash = invite_event.event_id # type: Union[str, Tuple]
+ if room_version.event_format == EventFormatVersions.V1:
+ alg, h = compute_event_reference_hash(invite_event)
+ invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
+
+ auth_events = tuple(invite_event.auth_events) + (invite_hash,)
+ prev_events = (invite_hash,)
+
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(invite_event.depth + 1, MAX_DEPTH)
+
+ event_dict = {
+ "depth": depth,
+ "auth_events": auth_events,
+ "prev_events": prev_events,
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "sender": target_user,
+ "content": content,
+ "state_key": target_user,
+ }
+
+ event = create_local_event_from_event_dict(
+ clock=self.clock,
+ hostname=self.hs.hostname,
+ signing_key=self.hs.signing_key,
+ room_version=room_version,
+ event_dict=event_dict,
+ )
+ event.internal_metadata.outlier = True
+ event.internal_metadata.out_of_band_membership = True
+ if txn_id is not None:
+ event.internal_metadata.txn_id = txn_id
+ if requester.access_token_id is not None:
+ event.internal_metadata.token_id = requester.access_token_id
+
+ EventValidator().validate_new(event, self.config)
+
+ context = await self.state_handler.compute_event_context(event)
+ context.app_service = requester.app_service
+ stream_id = await self.event_creation_handler.handle_new_client_event(
+ requester, event, context, extra_users=[UserID.from_string(target_user)],
+ )
+ return event.event_id, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 02e0c4103d..897338fd54 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -61,21 +61,22 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"]
- async def _remote_reject_invite(
+ async def remote_reject_invite(
self,
+ invite_event_id: str,
+ txn_id: Optional[str],
requester: Requester,
- remote_room_hosts: List[str],
- room_id: str,
- target: UserID,
content: dict,
- ) -> Tuple[Optional[str], int]:
- """Implements RoomMemberHandler._remote_reject_invite
+ ) -> Tuple[str, int]:
+ """
+ Rejects an out-of-band invite received from a remote user
+
+ Implements RoomMemberHandler.remote_reject_invite
"""
ret = await self._remote_reject_client(
+ invite_event_id=invite_event_id,
+ txn_id=txn_id,
requester=requester,
- remote_room_hosts=remote_room_hosts,
- room_id=room_id,
- user_id=target.to_string(),
content=content,
)
return ret["event_id"], ret["stream_id"]
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6c7abaa578..879c4c07c6 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -294,6 +294,9 @@ class TypingHandler(object):
rows.sort()
limited = False
+ # We, unusually, use a strict limit here as we have all the rows in
+ # memory rather than pulling them out of the database with a `LIMIT ?`
+ # clause.
if len(rows) > limit:
rows = rows[:limit]
current_id = rows[-1][0]
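
The comment above distinguishes this from the usual pagination pattern: fetching exactly `limit` rows via a SQL `LIMIT ?` cannot tell you whether more rows existed, whereas with all rows in memory a strict comparison can. In isolation:

    def apply_strict_limit(rows, limit):
        """Truncate in-memory rows, reporting whether truncation happened."""
        rows = sorted(rows)
        limited = len(rows) > limit
        if limited:
            rows = rows[:limit]
        return rows, limited
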
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 096619a8c2..479746c9c5 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -13,13 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource
-from synapse.http.server import wrap_json_request_handler
-
-class AdditionalResource(Resource):
+class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
If the user has configured additional_resources, we need to wrap the
@@ -41,16 +38,10 @@ class AdditionalResource(Resource):
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
- Resource.__init__(self)
+ super().__init__()
self._handler = handler
- # required by the request_handler wrapper
- self.clock = hs.get_clock()
-
- def render(self, request):
- self._async_render(request)
- return NOT_DONE_YET
-
- @wrap_json_request_handler
def _async_render(self, request):
+ # Cheekily pass the result straight through, so we don't need to worry
+ # if it's an awaitable or not.
return self._handler(request)
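
Given the `_async_render_<METHOD>` convention introduced in synapse/http/server.py later in this diff, a direct subclass of DirectServeJsonResource can be as small as the following (PingResource is a hypothetical example):

    from synapse.http.server import DirectServeJsonResource

    class PingResource(DirectServeJsonResource):
        # _send_response reads self.canonical_json; JsonResource sets it in
        # its own __init__, so a direct subclass must provide it itself.
        canonical_json = False

        async def _async_render_GET(self, request):
            # Returning a (code, object) tuple makes the base class
            # serialize the object as a JSON response.
            return 200, {"pong": True}
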
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 18f6a8fd29..148eeb19dc 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -176,7 +176,7 @@ class MatrixFederationHttpClient(object):
def __init__(self, hs, tls_client_options_factory):
self.hs = hs
- self.signing_key = hs.config.signing_key[0]
+ self.signing_key = hs.signing_key
self.server_name = hs.hostname
real_reactor = hs.get_reactor()
@@ -562,13 +562,17 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
- request = {"method": method, "uri": url_bytes, "origin": self.server_name}
+ request = {
+ "method": method.decode("ascii"),
+ "uri": url_bytes.decode("ascii"),
+ "origin": self.server_name,
+ }
if destination is not None:
- request["destination"] = destination
+ request["destination"] = destination.decode("ascii")
if destination_is is not None:
- request["destination_is"] = destination_is
+ request["destination_is"] = destination_is.decode("ascii")
if content is not None:
request["content"] = content
diff --git a/synapse/http/server.py b/synapse/http/server.py
index d192de7923..2b35f86066 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import collections
import html
import logging
@@ -21,7 +22,7 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
-from typing import Awaitable, Callable, TypeVar, Union
+from typing import Any, Callable, Dict, Tuple, Union
import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
"""
-def wrap_json_request_handler(h):
- """Wraps a request handler method with exception handling.
-
- Also does the wrapping with request.processing as per wrap_async_request_handler.
-
- The handler method must have a signature of "handle_foo(self, request)",
- where "request" must be a SynapseRequest.
-
- The handler must return a deferred or a coroutine. If the deferred succeeds
- we assume that a response has been sent. If the deferred fails with a SynapseError we use
- it to send a JSON response with the appropriate HTTP reponse code. If the
- deferred fails with any other type of error we send a 500 reponse.
+def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+ """Sends a JSON error response to clients.
"""
- async def wrapped_request_handler(self, request):
- try:
- await h(self, request)
- except SynapseError as e:
- code = e.code
- logger.info("%s SynapseError: %s - %s", request, code, e.msg)
-
- # Only respond with an error response if we haven't already started
- # writing, otherwise lets just kill the connection
- if request.startedWriting:
- if request.transport:
- try:
- request.transport.abortConnection()
- except Exception:
- # abortConnection throws if the connection is already closed
- pass
- else:
- respond_with_json(
- request,
- code,
- e.error_dict(),
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- )
-
- except Exception:
- # failure.Failure() fishes the original Failure out
- # of our stack, and thus gives us a sensible stack
- # trace.
- f = failure.Failure()
- logger.error(
- "Failed handle request via %r: %r",
- request.request_metrics.name,
- request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
- )
- # Only respond with an error response if we haven't already started
- # writing, otherwise lets just kill the connection
- if request.startedWriting:
- if request.transport:
- try:
- request.transport.abortConnection()
- except Exception:
- # abortConnection throws if the connection is already closed
- pass
- else:
- respond_with_json(
- request,
- 500,
- {"error": "Internal server error", "errcode": Codes.UNKNOWN},
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- )
-
- return wrap_async_request_handler(wrapped_request_handler)
-
-
-TV = TypeVar("TV")
-
-
-def wrap_html_request_handler(
- h: Callable[[TV, SynapseRequest], Awaitable]
-) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
- """Wraps a request handler method with exception handling.
+ if f.check(SynapseError):
+ error_code = f.value.code
+ error_dict = f.value.error_dict()
- Also does the wrapping with request.processing as per wrap_async_request_handler.
-
- The handler method must have a signature of "handle_foo(self, request)",
- where "request" must be a SynapseRequest.
- """
+ logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+ else:
+ error_code = 500
+ error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
- async def wrapped_request_handler(self, request):
- try:
- await h(self, request)
- except Exception:
- f = failure.Failure()
- return_html_error(f, request, HTML_ERROR_TEMPLATE)
+ logger.error(
+ "Failed handle request via %r: %r",
+ request.request_metrics.name,
+ request,
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
- return wrap_async_request_handler(wrapped_request_handler)
+ # Only respond with an error response if we haven't already started writing,
+ # otherwise lets just kill the connection
+ if request.startedWriting:
+ if request.transport:
+ try:
+ request.transport.abortConnection()
+ except Exception:
+ # abortConnection throws if the connection is already closed
+ pass
+ else:
+ respond_with_json(
+ request,
+ error_code,
+ error_dict,
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
def return_html_error(
@@ -249,7 +194,113 @@ class HttpServer(object):
pass
-class JsonResource(HttpServer, resource.Resource):
+class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
+ """Base class for resources that have async handlers.
+
+ Subclasses can either implement `_async_render_<METHOD>` to handle
+ requests by method, or override `_async_render` to handle all requests.
+
+ Args:
+ extract_context: Whether to attempt to extract the opentracing
+ context from the request the servlet is handling.
+ """
+
+ def __init__(self, extract_context=False):
+ super().__init__()
+
+ self._extract_context = extract_context
+
+ def render(self, request):
+ """ This gets called by twisted every time someone sends us a request.
+ """
+ defer.ensureDeferred(self._async_render_wrapper(request))
+ return NOT_DONE_YET
+
+ @wrap_async_request_handler
+ async def _async_render_wrapper(self, request):
+ """This is a wrapper that delegates to `_async_render` and handles
+ exceptions, return values, metrics, etc.
+ """
+ try:
+ request.request_metrics.name = self.__class__.__name__
+
+ with trace_servlet(request, self._extract_context):
+ callback_return = await self._async_render(request)
+
+ if callback_return is not None:
+ code, response = callback_return
+ self._send_response(request, code, response)
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ self._send_error_response(f, request)
+
+ async def _async_render(self, request):
+ """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
+ no appropriate method exists. Can be overridden in sub classes for
+ different routing.
+ """
+
+ method_handler = getattr(
+ self, "_async_render_%s" % (request.method.decode("ascii"),), None
+ )
+ if method_handler:
+ raw_callback_return = method_handler(request)
+
+ # Is it synchronous? We'll allow this for now.
+ if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ callback_return = await raw_callback_return
+ else:
+ callback_return = raw_callback_return
+
+ return callback_return
+
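+ # No handler exists for this request method; this call raises an
+ # UnrecognizedRequestError (HTTP 400).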
+ _unrecognised_request_handler(request)
+
+ @abc.abstractmethod
+ def _send_response(
+ self, request: SynapseRequest, code: int, response_object: Any,
+ ) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _send_error_response(
+ self, f: failure.Failure, request: SynapseRequest,
+ ) -> None:
+ raise NotImplementedError()
+
+
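
To make the new dispatch pattern concrete, here is a minimal, self-contained sketch of the `_async_render_<METHOD>` lookup that `_AsyncResource` performs. All names below are invented for the illustration, and Twisted, metrics and tracing are deliberately omitted:

```python
import asyncio
import types


class MiniAsyncResource:
    """Toy version of the `_async_render_<METHOD>` dispatch pattern."""

    async def _async_render(self, request_method: str, request):
        handler = getattr(self, "_async_render_%s" % (request_method,), None)
        if handler is None:
            return 400, {"error": "Unrecognized request"}

        raw = handler(request)
        # Handlers may be plain functions or coroutines; await only the latter.
        if isinstance(raw, types.CoroutineType):
            return await raw
        return raw


class EchoResource(MiniAsyncResource):
    async def _async_render_GET(self, request):
        return 200, {"echo": request}


print(asyncio.run(EchoResource()._async_render("GET", {"hello": "world"})))
# -> (200, {'echo': {'hello': 'world'}})
```
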
+class DirectServeJsonResource(_AsyncResource):
+ """A resource that will call `self._async_on_<METHOD>` on new requests,
+ formatting responses and errors as JSON.
+ """
+
+ def _send_response(
+ self, request, code, response_object,
+ ):
+ """Implements _AsyncResource._send_response
+ """
+ # TODO: Only enable CORS for the requests that need it.
+ respond_with_json(
+ request,
+ code,
+ response_object,
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ canonical_json=self.canonical_json,
+ )
+
+ def _send_error_response(
+ self, f: failure.Failure, request: SynapseRequest,
+ ) -> None:
+ """Implements _AsyncResource._send_error_response
+ """
+ return_json_error(f, request)
+
+
+class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
@@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
- def __init__(self, hs, canonical_json=True):
- resource.Resource.__init__(self)
+ def __init__(self, hs, canonical_json=True, extract_context=False):
+ super().__init__(extract_context)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
- def register_paths(
- self, method, path_patterns, callback, servlet_classname, trace=True
- ):
+ def register_paths(self, method, path_patterns, callback, servlet_classname):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
@@ -295,37 +344,42 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
-
- trace (bool): Whether we should start a span to trace the servlet.
"""
method = method.encode("utf-8") # method is bytes on py3
- if trace:
- # We don't extract the context from the servlet because we can't
- # trust the sender
- callback = trace_servlet(servlet_classname)(callback)
-
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
)
- def render(self, request):
- """ This gets called by twisted every time someone sends us a request.
+ def _get_handler_for_request(
+ self, request: SynapseRequest
+ ) -> Tuple[Callable, str, Dict[str, str]]:
+ """Finds a callback method to handle the given request.
+
+ Returns:
+ A tuple of the callback to use, the name of the servlet, and the
+ keyword arguments to pass to the callback
"""
- defer.ensureDeferred(self._async_render(request))
- return NOT_DONE_YET
+ request_path = request.path.decode("ascii")
+
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request_path)
+ if m:
+ # We found a match!
+ return path_entry.callback, path_entry.servlet_classname, m.groupdict()
+
+ # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+ return _unrecognised_request_handler, "unrecognised_request_handler", {}
- @wrap_json_request_handler
async def _async_render(self, request):
- """ This gets called from render() every time someone sends us a request.
- This checks if anyone has registered a callback for that method and
- path.
- """
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- # Make sure we have a name for this handler in prometheus.
+ # Make sure we have an appropriate name for this handler in prometheus
+ # (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
@@ -338,81 +392,42 @@ class JsonResource(HttpServer, resource.Resource):
}
)
- callback_return = callback(request, **kwargs)
+ raw_callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
- if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
- callback_return = await callback_return
+ if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ callback_return = await raw_callback_return
+ else:
+ callback_return = raw_callback_return
- if callback_return is not None:
- code, response = callback_return
- self._send_response(request, code, response)
+ return callback_return
- def _get_handler_for_request(self, request):
- """Finds a callback method to handle the given request
- Args:
- request (twisted.web.http.Request):
+class DirectServeHtmlResource(_AsyncResource):
+ """A resource that will call `self._async_on_<METHOD>` on new requests,
+ formatting responses and errors as HTML.
+ """
- Returns:
- Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
- label to use for that method in prometheus metrics, and the
- dict mapping keys to path components as specified in the
- handler's path match regexp.
-
- The callback will normally be a method registered via
- register_paths, so will return (possibly via Deferred) either
- None, or a tuple of (http code, response body).
- """
- request_path = request.path.decode("ascii")
-
- # Loop through all the registered callbacks to check if the method
- # and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request_path)
- if m:
- # We found a match!
- return path_entry.callback, path_entry.servlet_classname, m.groupdict()
-
- # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- return _unrecognised_request_handler, "unrecognised_request_handler", {}
+ # The error template to use for this resource
+ ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
- self, request, code, response_json_object, response_code_message=None
+ self, request: SynapseRequest, code: int, response_object: Any,
):
- # TODO: Only enable CORS for the requests that need it.
- respond_with_json(
- request,
- code,
- response_json_object,
- send_cors=True,
- response_code_message=response_code_message,
- pretty_print=_request_user_agent_is_curl(request),
- canonical_json=self.canonical_json,
- )
-
-
-class DirectServeResource(resource.Resource):
- def render(self, request):
+ """Implements _AsyncResource._send_response
"""
- Render the request, using an asynchronous render handler if it exists.
- """
- async_render_callback_name = "_async_render_" + request.method.decode("ascii")
-
- # Try and get the async renderer
- callback = getattr(self, async_render_callback_name, None)
+ # We expect the response object to be bytes for us to write out
+ assert isinstance(response_object, bytes)
+ html_bytes = response_object
- # No async renderer for this request method.
- if not callback:
- return super().render(request)
+ respond_with_html_bytes(request, 200, html_bytes)
- resp = trace_servlet(self.__class__.__name__)(callback)(request)
-
- # If it's a coroutine, turn it into a Deferred
- if isinstance(resp, types.CoroutineType):
- defer.ensureDeferred(resp)
-
- return NOT_DONE_YET
+ def _send_error_response(
+ self, f: failure.Failure, request: SynapseRequest,
+ ) -> None:
+ """Implements _AsyncResource._send_error_response
+ """
+ return_html_error(f, request, self.ERROR_TEMPLATE)
class StaticResource(File):
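
Taken together, the `JsonResource` changes keep the routing model: handlers are registered against compiled regexes per method, and named groups in the pattern become keyword arguments for the callback. A standalone sketch of that flow, with all names invented for the example:

```python
import re
from collections import namedtuple

_PathEntry = namedtuple("_PathEntry", ["pattern", "callback", "servlet_classname"])


class MiniRouter:
    def __init__(self):
        self.path_regexs = {}

    def register_paths(self, method, path_patterns, callback, servlet_classname):
        for path_pattern in path_patterns:
            self.path_regexs.setdefault(method, []).append(
                _PathEntry(path_pattern, callback, servlet_classname)
            )

    def get_handler(self, method, path):
        # First matching pattern wins; otherwise fall back to a 400 handler.
        for path_entry in self.path_regexs.get(method, []):
            m = path_entry.pattern.match(path)
            if m:
                return path_entry.callback, path_entry.servlet_classname, m.groupdict()
        return None, "unrecognised_request_handler", {}


router = MiniRouter()
router.register_paths(
    "GET",
    [re.compile(r"^/_matrix/client/r0/rooms/(?P<room_id>[^/]*)/state$")],
    lambda request, room_id: (200, {"room_id": room_id}),
    "RoomStateServlet",
)

callback, name, kwargs = router.get_handler(
    "GET", "/_matrix/client/r0/rooms/!abc:example.org/state"
)
print(name, kwargs)  # RoomStateServlet {'room_id': '!abc:example.org'}
```
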
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 73bef5e5ca..c6c0e623c1 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -164,12 +164,10 @@ Gotchas
than one caller? Will all of those calling functions have to be in a context
with an active span?
"""
-
import contextlib
import inspect
import logging
import re
-import types
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
@@ -181,6 +179,7 @@ from twisted.internet import defer
from synapse.config import ConfigError
if TYPE_CHECKING:
+ from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
# Helper class
@@ -227,6 +226,7 @@ except ImportError:
tags = _DummyTagNames
try:
from jaeger_client import Config as JaegerConfig
+
from synapse.logging.scopecontextmanager import LogContextScopeManager
except ImportError:
JaegerConfig = None # type: ignore
@@ -793,48 +793,42 @@ def tag_args(func):
return _tag_args_inner
-def trace_servlet(servlet_name, extract_context=False):
- """Decorator which traces a serlet. It starts a span with some servlet specific
- tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+ """Returns a context manager which traces a request. It starts a span
+ with some servlet specific tags such as the request metrics name and
+ request information.
Args:
- servlet_name (str): The name to be used for the span's operation_name
- extract_context (bool): Whether to attempt to extract the opentracing
+ request
+ extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
-
"""
- def _trace_servlet_inner_1(func):
- if not opentracing:
- return func
-
- @wraps(func)
- async def _trace_servlet_inner(request, *args, **kwargs):
- request_tags = {
- "request_id": request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientIP(),
- }
-
- if extract_context:
- scope = start_active_span_from_request(
- request, servlet_name, tags=request_tags
- )
- else:
- scope = start_active_span(servlet_name, tags=request_tags)
-
- with scope:
- result = func(request, *args, **kwargs)
+ if opentracing is None:
+ yield
+ return
- if not isinstance(result, (types.CoroutineType, defer.Deferred)):
- # Some servlets aren't async and just return results
- # directly, so we handle that here.
- return result
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ }
- return await result
+ request_name = request.request_metrics.name
+ if extract_context:
+ scope = start_active_span_from_request(request, request_name, tags=request_tags)
+ else:
+ scope = start_active_span(request_name, tags=request_tags)
- return _trace_servlet_inner
+ with scope:
+ try:
+ yield
+ finally:
+ # We set the operation name again in case it's changed (which happens
+ # with JsonResource).
+ scope.span.set_operation_name(request.request_metrics.name)
- return _trace_servlet_inner_1
+ scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
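
The decorator-to-context-manager conversion of `trace_servlet` follows the standard `contextlib` shape; a minimal standalone sketch of the same move, with invented names, `print` standing in for span start/finish, and no opentracing dependency:

```python
import asyncio
import contextlib


@contextlib.contextmanager
def trace_request(request_name: str):
    # Entering the `with` block corresponds to starting the span...
    print("start span: %s" % (request_name,))
    try:
        yield
    finally:
        # ...and the finally block runs when the scope closes, mirroring the
        # operation-name reset done above.
        print("finish span: %s" % (request_name,))


async def handle(request_name: str):
    with trace_request(request_name):
        return 200


print(asyncio.run(handle("GET-RoomStateServlet")))
```
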
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 13785038ad..a9269196b3 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
+from twisted.python.failure import Failure
from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -212,7 +213,14 @@ def run_as_background_process(desc, func, *args, **kwargs):
return (yield result)
except Exception:
- logger.exception("Background process '%s' threw an exception", desc)
+ # failure.Failure() fishes the original Failure out of our stack, and
+ # thus gives us a sensible stack trace.
+ f = Failure()
+ logger.error(
+ "Background process '%s' threw an exception",
+ desc,
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
finally:
_background_process_in_flight_count.labels(desc).dec()
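
The `exc_info` idiom above can be tried in isolation; a small sketch, assuming only that Twisted is installed (the logger name and exception are invented):

```python
import logging

from twisted.python.failure import Failure

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("background_process")


def run():
    try:
        raise ValueError("boom")
    except Exception:
        # Failure() with no arguments captures the exception currently being
        # handled, including its original traceback.
        f = Failure()
        logger.error(
            "Background process '%s' threw an exception",
            "demo",
            exc_info=(f.type, f.value, f.getTracebackObject()),
        )


run()
```
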
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 87c120a59c..bd41f77852 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -83,7 +83,7 @@ class _NotifierUserStream(object):
self.current_token = current_token
# The last token for which we should wake up any streams that have a
- # token that comes before it. This gets updated everytime we get poked.
+ # token that comes before it. This gets updated every time we get poked.
# We start it at the current token since if we get any streams
# that have a token from before we have no idea whether they should be
# woken up or not, so let's just wake them up.
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index ed60dbc1bf..2fac07593b 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -20,6 +20,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from synapse.api.constants import EventTypes
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
@@ -305,12 +306,23 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
+ priority = "low"
+ if (
+ event.type == EventTypes.Encrypted
+ or tweaks.get("highlight")
+ or tweaks.get("sound")
+ ):
+ # HACK: send our push as high priority only if it generates a sound or a
+ # highlight, or may do so (i.e. is encrypted, so has unknown effects).
+ priority = "high"
+
if self.data.get("format") == "event_id_only":
d = {
"notification": {
"event_id": event.event_id,
"room_id": event.room_id,
"counts": {"unread": badge},
+ "prio": priority,
"devices": [
{
"app_id": self.app_id,
@@ -334,9 +346,8 @@ class HttpPusher(object):
"room_id": event.room_id,
"type": event.type,
"sender": event.user_id,
- "counts": { # -- we don't mark messages as read yet so
- # we have no way of knowing
- # Just set the badge to 1 until we have read receipts
+ "prio": priority,
+ "counts": {
"unread": badge,
# 'missed_calls': 2
},
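
The priority heuristic reads naturally as a small pure function. A hedged restatement for illustration (in synapse, `EventTypes.Encrypted` is `"m.room.encrypted"`):

```python
def notification_priority(event_type: str, tweaks: dict) -> str:
    """Mirror of the hack above: high priority only for pushes that will (or,
    for encrypted events, may) generate a sound or highlight."""
    if (
        event_type == "m.room.encrypted"
        or tweaks.get("highlight")
        or tweaks.get("sound")
    ):
        return "high"
    return "low"


assert notification_priority("m.room.message", {}) == "low"
assert notification_priority("m.room.message", {"sound": "default"}) == "high"
assert notification_priority("m.room.encrypted", {}) == "high"
```
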
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 8e0d3a416d..2d79ada189 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
import logging
import re
-from typing import Pattern
+from typing import Any, Dict, List, Pattern, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number):
return False
-def tweaks_for_actions(actions):
+def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
+ """
+ Converts a list of actions into a `tweaks` dict (which can then be passed to
+ the push gateway).
+
+ This function ignores all actions other than `set_tweak` actions, and treats
+ absent `value`s as `True`, which agrees with the only spec-defined treatment
+ of absent `value`s (namely, for `highlight` tweaks).
+
+ Args:
+ actions: list of actions
+ e.g. [
+ {"set_tweak": "a", "value": "AAA"},
+ {"set_tweak": "b", "value": "BBB"},
+ {"set_tweak": "highlight"},
+ "notify"
+ ]
+
+ Returns:
+ dictionary of tweaks for those actions
+ e.g. {"a": "AAA", "b": "BBB", "highlight": True}
+ """
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
- if "set_tweak" in a and "value" in a:
- tweaks[a["set_tweak"]] = a["value"]
+ if "set_tweak" in a:
+ # value is allowed to be absent in which case the value assumed
+ # should be True.
+ tweaks[a["set_tweak"]] = a.get("value", True)
return tweaks
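
A quick usage example of the new behaviour (the function body is copied from the hunk above, so this runs standalone):

```python
def tweaks_for_actions(actions):
    tweaks = {}
    for a in actions:
        if not isinstance(a, dict):
            continue
        if "set_tweak" in a:
            tweaks[a["set_tweak"]] = a.get("value", True)
    return tweaks


actions = [
    {"set_tweak": "sound", "value": "default"},
    {"set_tweak": "highlight"},  # absent value now defaults to True
    "notify",  # non-dict actions are skipped
]
assert tweaks_for_actions(actions) == {"sound": "default", "highlight": True}
```
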
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index b1cac901eb..8cfcdb0573 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -66,7 +66,7 @@ REQUIREMENTS = [
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
- "prometheus_client>=0.0.18,<0.8.0",
+ "prometheus_client>=0.0.18,<0.9.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0
"attrs>=19.1.0",
"netaddr>=0.7.18",
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 19b69e0e11..5ef1c6c1dc 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
- JsonResource.__init__(self, hs, canonical_json=False)
+ # We enable extracting jaeger contexts here as these are internal APIs.
+ super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs):
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 9caf1e80c1..fb0dd04f88 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -28,11 +28,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.logging.opentracing import (
- inject_active_span_byte_dict,
- trace,
- trace_servlet,
-)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -96,11 +92,11 @@ class ReplicationEndpoint(object):
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert self.METHOD in ("PUT", "POST", "GET")
@@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
- # We don't let register paths trace this servlet using the default tracing
- # options because we wish to extract the context explicitly.
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__, trace=False
+ method, [pattern], handler, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index a7174c4a8f..63ef6eb7be 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,11 +14,11 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
@@ -88,49 +88,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room.
+ """Rejects an out-of-band invite we have received from a remote server
Request format:
- POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
+ POST /_synapse/replication/remote_reject_invite/:event_id
{
+ "txn_id": ...,
"requester": ...,
- "remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
+ PATH_ARGS = ("invite_event_id",)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
- self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ def _serialize_payload( # type: ignore
+ invite_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and reject via
+ invite_event_id: ID of the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
"""
return {
+ "txn_id": txn_id,
"requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, invite_event_id):
content = parse_json_object_from_request(request)
- remote_room_hosts = content["remote_room_hosts"]
+ txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -138,60 +143,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
-
- try:
- event, stream_id = await self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, user_id, event_content,
- )
- event_id = event.event_id
- except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
- #
- # The 'except' clause is very broad, but we need to
- # capture everything from DNS failures upwards
- #
- logger.warning("Failed to reject invite: %s", e)
-
- stream_id = await self.member_handler.locally_reject_invite(
- user_id, room_id
- )
- event_id = None
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_reject_invite(
+ invite_event_id, txn_id, requester, event_content,
+ )
return 200, {"event_id": event_id, "stream_id": stream_id}
-class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room locally.
-
- Request format:
-
- POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
-
- {}
- """
-
- NAME = "locally_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- self.member_handler = hs.get_room_member_handler()
-
- @staticmethod
- def _serialize_payload(room_id, user_id):
- return {}
-
- async def _handle_request(self, request, room_id, user_id):
- logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
-
- stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
-
- return 200, {"stream_id": stream_id}
-
-
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -245,4 +204,3 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
- ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..525b94fd87 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,6 +16,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore
from synapse.storage.database import Database
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "tag_account_data":
+ if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- elif stream_name == "account_data":
+ elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..bd394f6b00 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "to_device":
+ if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
for row in rows:
if row.entity.startswith("@"):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..5d210fa3a1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "groups":
+ if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..2938cb8e43 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "presence":
+ if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..23ec1c5b11 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "push_rules":
+ if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..ff449f3658 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PushersStream
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.database import Database
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "pushers":
+ if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..6982686eb5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,20 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "receipts":
+ if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..8710207ada 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PublicRoomsStream
from synapse.storage.data_stores.main.room import RoomWorkerStore
from synapse.storage.database import Database
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "public_rooms":
+ if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 523a1358d4..1b8718b11d 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -25,7 +25,7 @@ Structure of the module:
* command.py - the definitions of all the valid commands
* protocol.py - the TCP protocol classes
* resource.py - handles streaming stream updates to replications
- * streams/ - the definitons of all the valid streams
+ * streams/ - the definitions of all the valid streams
The general interaction of the classes are:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index df29732f51..4985e40b1f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
- from synapse.server import HomeServer
from synapse.replication.tcp.handler import ReplicationCommandHandler
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ea5937a20c..ccc7f1f0d1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,18 +18,11 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
allowed to be sent by which side.
"""
import abc
+import json
import logging
-import platform
from typing import Tuple, Type
-if platform.python_implementation() == "PyPy":
- import json
-
- _json_encoder = json.JSONEncoder()
-else:
- import simplejson as json # type: ignore[no-redef] # noqa: F821
-
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
+_json_encoder = json.JSONEncoder()
logger = logging.getLogger(__name__)
@@ -54,7 +47,7 @@ class Command(metaclass=abc.ABCMeta):
@abc.abstractmethod
def to_line(self) -> str:
- """Serialises the comamnd for the wire. Does not include the command
+ """Serialises the command for the wire. Does not include the command
prefix.
"""
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e6a2e2598b..55b3b79008 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
@@ -149,10 +148,11 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
+ import txredisapi
+
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
)
- import txredisapi
logger.info(
"Connecting to redis (host=%r port=%r)",
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4198eece71..ca47f5cc88 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -317,7 +317,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
- logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+ logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..0a7e7f67be 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -177,7 +177,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
- send outbound commands (this is seperate to the redis connection
+ send outbound commands (this is separate to the redis connection
used to subscribe).
"""
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f196eff072..9076bbe9f1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -198,26 +198,6 @@ def current_token_without_instance(
return lambda instance_name: current_token()
-def db_query_to_update_function(
- query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> UpdateFunction:
- """Wraps a db query function which returns a list of rows to make it
- suitable for use as an `update_function` for the Stream class
- """
-
- async def update_function(instance_name, from_token, upto_token, limit):
- rows = await query_function(from_token, upto_token, limit)
- updates = [(row[0], row[1:]) for row in rows]
- limited = False
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
-
- return update_function
-
-
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
@@ -393,7 +373,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
- db_query_to_update_function(store.get_all_updated_pushers_rows),
+ store.get_all_updated_pushers_rows,
)
@@ -421,26 +401,12 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- self.store = hs.get_datastore()
+ store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_cache_stream_token,
- self._update_function,
- )
-
- async def _update_function(
- self, instance_name: str, from_token: int, upto_token: int, limit: int
- ):
- rows = await self.store.get_all_updated_caches(
- instance_name, from_token, upto_token, limit
+ store.get_cache_stream_token,
+ store.get_all_updated_caches,
)
- updates = [(row[0], row[1:]) for row in rows]
- limited = False
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
class PublicRoomsStream(Stream):
@@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_public_room_stream_id),
- db_query_to_update_function(store.get_all_new_public_rooms),
+ store.get_all_new_public_rooms,
)
@@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
- db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+ store.get_all_device_list_changes_for_remotes,
)
@@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
- db_query_to_update_function(store.get_all_new_device_messages),
+ store.get_all_new_device_messages,
)
@@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id),
- db_query_to_update_function(store.get_all_updated_tags),
+ store.get_all_updated_tags,
)
@@ -612,7 +578,7 @@ class GroupServerStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_group_stream_token),
- db_query_to_update_function(store.get_all_groups_changes),
+ store.get_all_groups_changes,
)
@@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
- db_query_to_update_function(
- store.get_all_user_signature_changes_for_remotes
- ),
+ store.get_all_user_signature_changes_for_remotes,
)
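
The `db_query_to_update_function` wrapper could be deleted because the store methods now return results in the shape `Stream` expects directly. As a reminder of that contract, a hedged sketch of an `update_function` (the signature matches the one used above; the body is a stub):

```python
import asyncio
from typing import List, Tuple

# (position, row) pairs, plus the token reached and whether we hit the limit.
StreamUpdateResult = Tuple[List[Tuple[int, tuple]], int, bool]


async def update_function(
    instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
    # A real implementation queries the database for up to `limit` rows with
    # stream positions in (from_token, upto_token].
    rows = []  # e.g. [(token, row_data), ...]
    limited = False
    if len(rows) >= limit:
        upto_token = rows[-1][0]
        limited = True
    return rows, upto_token, limited


print(asyncio.run(update_function("worker1", 0, 100, limit=50)))
# -> ([], 100, False) since the stub query returns no rows
```
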
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f370390331..1c2a4cce7f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import heapq
from collections import Iterable
from typing import List, Tuple, Type
@@ -22,7 +21,6 @@ import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
-
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -64,7 +62,7 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overridden in sub classes.
TypeId = None # type: str
@classmethod
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index bf0f9bd077..64d5c58b65 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -26,8 +27,9 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -113,7 +115,7 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -141,10 +143,10 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- def on_OPTIONS(self, request):
+ def on_OPTIONS(self, request: SynapseRequest):
return 200, {}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
@@ -153,9 +155,9 @@ class LoginRestServlet(RestServlet):
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- result = await self.do_jwt_login(login_submission)
+ result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = await self.do_token_login(login_submission)
+ result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError:
@@ -166,14 +168,14 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- async def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
- dict: HTTP response
+ HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -206,11 +208,14 @@ class LoginRestServlet(RestServlet):
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- address = address.lower()
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
@@ -288,25 +293,30 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
- self, user_id, login_submission, callback=None, create_non_existent_users=False
- ):
+ self,
+ user_id: str,
+ login_submission: JsonDict,
+ callback: Optional[
+ Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
+ ] = None,
+ create_non_existent_users: bool = False,
+ ) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
- all succesful logins.
+ all successful logins.
- Applies the ratelimiting for succesful login attempts against an
+ Applies the ratelimiting for successful login attempts against an
account.
Args:
- user_id (str): ID of the user to register.
- login_submission (dict): Dictionary of login information.
- callback (func|None): Callback function to run after registration.
- create_non_existent_users (bool): Whether to create the user if
- they don't exist. Defaults to False.
+ user_id: ID of the user to register.
+ login_submission: Dictionary of login information.
+ callback: Callback function to run after registration.
+ create_non_existent_users: Whether to create the user if they don't
+ exist. Defaults to False.
Returns:
- result (Dict[str,str]): Dictionary of account information after
- successful registration.
+ result: Dictionary of account information after successful registration.
"""
# Before we actually log them in we check if they've already logged in
@@ -340,7 +350,7 @@ class LoginRestServlet(RestServlet):
return result
- async def do_token_login(self, login_submission):
+ async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -350,7 +360,7 @@ class LoginRestServlet(RestServlet):
result = await self._complete_login(user_id, login_submission)
return result
- async def do_jwt_login(self, login_submission):
+ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 46811abbfa..f40ed82142 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -217,10 +217,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
event_id = event.event_id
- ret = {} # type: dict
- if event_id:
- set_tag("event_id", event_id)
- ret = {"event_id": event_id}
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 747d46eac2..50277c6cf6 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet):
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
- password = base64.b64encode(mac.digest())
+ password = base64.b64encode(mac.digest()).decode("ascii")
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
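
The `.decode("ascii")` matters because `base64.b64encode` returns `bytes`, which cannot go into the JSON response. A standalone sketch of the TURN-REST-style credential computation happening here (the secret and username are invented):

```python
import base64
import hashlib
import hmac

turn_shared_secret = b"not-a-real-secret"
username = "1593000000:@alice:example.org"  # expiry:user, as in TURN REST auth

mac = hmac.new(
    turn_shared_secret, msg=username.encode("utf-8"), digestmod=hashlib.sha1
)
# b64encode returns bytes; decode to str so it can be JSON-serialised.
password = base64.b64encode(mac.digest()).decode("ascii")
print(type(password), password)
```
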
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 182a308eef..3767a809a4 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -30,7 +30,7 @@ from synapse.http.servlet import (
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -83,7 +83,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database. This allows the user to reset their password without having
+ # to know the exact spelling (e.g. upper and lower case) of the address in the
+ # database. For example, if "foo@bar.com" is stored, a request for
+ # "FOO@bar.com" would otherwise raise a Not Found error.
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -94,6 +102,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # The email will be sent to the stored address.
+ # This avoids a potential account hijack by requesting a password reset to
+ # an email address which is controlled by the attacker but which, after
+ # canonicalisation, matches the one in our database.
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
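
For reference, a hedged sketch of the kind of normalisation `canonicalise_email` performs; the real helper lives in `synapse/util/threepids.py`, and this is an approximation rather than its exact code:

```python
def canonicalise_email_sketch(address: str) -> str:
    """Approximation: strip whitespace, require a single '@', case-fold the
    local part and lower-case the domain."""
    address = address.strip()
    parts = address.split("@")
    if len(parts) != 2:
        raise ValueError("Unable to parse email address")
    return parts[0].casefold() + "@" + parts[1].lower()


assert canonicalise_email_sketch(" FOO@Bar.COM ") == "foo@bar.com"
```
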
@@ -274,10 +286,13 @@ class PasswordRestServlet(RestServlet):
if "medium" not in threepid or "address" not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid["medium"] == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
- threepid["address"] = threepid["address"].lower()
+ try:
+ threepid["address"] = canonicalise_email(threepid["address"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# if using email, we must know about the email they're authing with!
threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"]
@@ -392,7 +407,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database.
+ # This ensures that the validation email is sent to the canonicalised address
+ # as it will later be entered into the database.
+ # Otherwise the email will be sent to "FOO@bar.com" and stored as
+ # "foo@bar.com" in database.
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -403,9 +427,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = await self.store.get_user_id_by_threepid(
- "email", body["email"]
- )
+ existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
if self.config.request_token_inhibit_3pid_errors:
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 56a451c42f..370742ce59 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -47,7 +47,7 @@ from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -116,7 +116,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See on_POST in EmailThreepidRequestTokenRestServlet
+ # in synapse/rest/client/v2_alpha/account.py)
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -128,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- "email", body["email"]
+ "email", email
)
if existing_user_id is not None:
@@ -552,6 +559,15 @@ class RegisterRestServlet(RestServlet):
if login_type in auth_result:
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See on_POST in EmailThreepidRequestTokenRestServlet
+ # in synapse/rest/client/v2_alpha/account.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
existing_user_id = await self.store.get_user_id_by_threepid(
medium, address
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 0a890c98cb..4386eb4e72 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -26,11 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
-from synapse.http.server import (
- DirectServeResource,
- respond_with_html,
- wrap_html_request_handler,
-)
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_string
from synapse.types import UserID
@@ -48,7 +44,7 @@ else:
return a == b
-class ConsentResource(DirectServeResource):
+class ConsentResource(DirectServeHtmlResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
@@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- @wrap_html_request_handler
async def _async_render_GET(self, request):
"""
Args:
@@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- @wrap_html_request_handler
async def _async_render_POST(self, request):
"""
Args:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index ab671f7334..e149ac1733 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,17 +20,13 @@ from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json_bytes,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
-class RemoteKey(DirectServeResource):
+class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
isLeaf = True
def __init__(self, hs):
+ super().__init__()
+
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
- @wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- @wrap_json_request_handler
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
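This change, and the media resource changes below, all follow one shape: DirectServeResource plus a wrap_*_request_handler decorator becomes DirectServeJsonResource (or DirectServeHtmlResource), which does the wrapping itself, and render_OPTIONS becomes an async renderer with no NOT_DONE_YET. A minimal sketch of the post-refactor shape (PingResource and its payload are hypothetical):

```python
from synapse.http.server import DirectServeJsonResource, respond_with_json

class PingResource(DirectServeJsonResource):
    isLeaf = True

    async def _async_render_GET(self, request):
        # No decorator needed: the base class wraps the coroutine and
        # renders any exception as a JSON error response.
        respond_with_json(request, 200, {"pong": True}, send_cors=True)

    async def _async_render_OPTIONS(self, request):
        # OPTIONS is now just another async renderer.
        respond_with_json(request, 200, {}, send_cors=True)
```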
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9f747de263..68dd2a1c8a 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -14,16 +14,10 @@
# limitations under the License.
#
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json,
- wrap_json_request_handler,
-)
-
-class MediaConfigResource(DirectServeResource):
+class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- @wrap_json_request_handler
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
- return NOT_DONE_YET
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 24d3ae5bbc..d3d8457303 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,18 +15,14 @@
import logging
import synapse.http.servlet
-from synapse.http.server import (
- DirectServeResource,
- set_cors_headers,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
-class DownloadResource(DirectServeResource):
+class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
self.media_repo = media_repo
self.server_name = hs.hostname
- # this is expected by @wrap_json_request_handler
- self.clock = hs.get_clock()
-
- @wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b4645cd608..e52c86c798 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
- DirectServeResource,
+ DirectServeJsonResource,
respond_with_json,
respond_with_json_bytes,
- wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
-class PreviewUrlResource(DirectServeResource):
+class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
self._start_expire_url_cache_data, 10 * 1000
)
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET")
- return respond_with_json(request, 200, {}, send_cors=True)
+ respond_with_json(request, 200, {}, send_cors=True)
- @wrap_json_request_handler
async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 0b87220234..a83535b97b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,11 +16,7 @@
import logging
-from synapse.http.server import (
- DirectServeResource,
- set_cors_headers,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from ._base import (
@@ -34,7 +30,7 @@ from ._base import (
logger = logging.getLogger(__name__)
-class ThumbnailResource(DirectServeResource):
+class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- self.clock = hs.get_clock()
- @wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index c234ea7421..7126997134 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from io import BytesIO
-import PIL.Image as Image
+from PIL import Image as Image
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 83d005812d..3ebf7a68e6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,20 +15,14 @@
import logging
-from twisted.web.server import NOT_DONE_YET
-
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
-class UploadResource(DirectServeResource):
+class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
- return NOT_DONE_YET
- @wrap_json_request_handler
async def _async_render_POST(self, request):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
index c03194f001..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/oidc/callback_resource.py
@@ -14,18 +14,17 @@
# limitations under the License.
import logging
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.http.server import DirectServeHtmlResource
logger = logging.getLogger(__name__)
-class OIDCCallbackResource(DirectServeResource):
+class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
def __init__(self, hs):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- @wrap_html_request_handler
async def _async_render_GET(self, request):
- return await self._oidc_handler.handle_oidc_callback(request)
+ await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 75e58043b4..c10188a5d7 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -16,10 +16,10 @@
from twisted.python import failure
from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource, return_html_error
-class SAML2ResponseResource(DirectServeResource):
+class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
diff --git a/synapse/secrets.py b/synapse/secrets.py
index 0b327a0f82..5f43f81eb0 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -19,7 +19,6 @@ Injectable secrets module for Synapse.
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
used in Python 3.6, and the API emulated in Python 2.7.
"""
-
import sys
# secrets is available since python 3.6
@@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6):
else:
- import os
import binascii
+ import os
class Secrets(object):
def token_bytes(self, nbytes=32):
diff --git a/synapse/server.py b/synapse/server.py
index fe94836a2c..6acce2e23f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -232,6 +232,8 @@ class HomeServer(object):
self._reactor = reactor
self.hostname = hostname
+ # the key we use to sign events and requests
+ self.signing_key = config.key.signing_key[0]
self.config = config
self._building = {}
self._listening_services = []
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index eac5a4e55b..f39f556c20 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,10 +16,12 @@
import itertools
import logging
-from typing import Any, Iterable, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
+ EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
@@ -44,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ):
- """Fetches cache invalidation rows between the two given IDs written
- by the given instance. Returns at most `limit` rows.
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for caches replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return []
+ return [], current_id, False
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
@@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
- return txn.fetchall()
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
+ if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
- elif stream_name == "backfill":
+ elif stream_name == BackfillStream.NAME:
for row in rows:
self._invalidate_caches_for_event(
-token,
@@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
row.relates_to,
backfilled=True,
)
- elif stream_name == "caches":
+ elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
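The docstring added above describes a contract shared by every get_all_updated_* method converted in this change: return the rows, a resumption token, and whether the limit may have truncated the result. A compact sketch of the limit handling (paginate_updates is a hypothetical free function; in the real stores this logic sits inside each transaction):

```python
from typing import List, Tuple

def paginate_updates(
    rows: List[Tuple[int, tuple]], current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    # rows are (stream_id, row_data) pairs, already capped by SQL LIMIT.
    limited = False
    upto_token = current_id
    if len(rows) >= limit:
        # We may have truncated, so hand back the last stream ID we
        # actually returned; the next call resumes from there.
        upto_token = rows[-1][0]
        limited = True
    return rows, upto_token, limited
```

A caller loops, feeding the returned token back in as last_id, until limited is False.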
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 9a1178fb39..d313b9705f 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
+ async def get_all_new_device_messages(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for to device replication stream.
+
Args:
- last_pos(int):
- current_pos(int):
- limit(int):
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
Returns:
- A deferred list of rows from the device inbox
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- if last_pos == current_pos:
- return defer.succeed([])
+
+ if last_id == current_id:
+ return [], current_id, False
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
- upper_pos = min(current_pos, last_pos + limit)
+ upper_pos = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
+ txn.execute(sql, (last_id, upper_pos))
+ updates = [(row[0], row[1:]) for row in txn]
sql = (
"SELECT max(stream_id), destination"
@@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
+ txn.execute(sql, (last_id, upper_pos))
+ updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
- rows.sort()
+ updates.sort()
- return rows
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- return self.db.runInteraction(
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
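One subtlety in the transaction above: device_inbox can hold several rows per stream_id, so a plain SQL LIMIT could split a stream ID's rows across two batches. Bounding the stream-ID window instead keeps every returned stream ID complete. A toy illustration with hypothetical values:

```python
last_id, current_id, limit = 100, 10_000, 50

# At most `limit` distinct stream IDs fit in this window, and the
# boundary can never fall in the middle of one stream ID's rows.
upper_pos = min(current_id, last_id + limit)
assert upper_pos == 150
# The SQL then selects rows with last_id < stream_id <= upper_pos.
```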
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 0ff0542453..343cf9a2d5 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
async def get_all_device_list_changes_for_remotes(
- self, from_key: int, to_key: int, limit: int,
- ) -> List[Tuple[int, str]]:
- """Return a list of `(stream_id, entity)` which is the combined list of
- changes to devices and which destinations need to be poked. Entity is
- either a user ID (starting with '@') or a remote destination.
- """
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for device lists replication stream.
- # This query Does The Right Thing where it'll correctly apply the
- # bounds to the inner queries.
- sql = """
- SELECT stream_id, entity FROM (
- SELECT stream_id, user_id AS entity FROM device_lists_stream
- UNION ALL
- SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
- ) AS e
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- return await self.db.execute(
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_device_list_changes_for_remotes(txn):
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
+ sql = """
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_device_list_changes_for_remotes",
- None,
- sql,
- from_key,
- to_key,
- limit,
+ _get_all_device_list_changes_for_remotes,
)
@cached(max_entries=10000)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 1a0842d4b0..6c3cff82e1 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List
+from typing import Dict, List, Tuple
from canonicaljson import encode_canonical_json, json
@@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
- """Return a list of changes from the user signature stream to notify remotes.
+ async def get_all_user_signature_changes_for_remotes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for groups replication stream.
+
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
- from_key (int): the stream ID to start at (exclusive)
- to_key (int): the stream ID to end at (inclusive)
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
Returns:
- Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
- """
- sql = """
- SELECT stream_id, from_user_id AS user_id
- FROM user_signature_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- return self.db.execute(
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_user_signature_changes_for_remotes_txn(txn):
+ sql = """
+ SELECT stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
+
+ updates = [(row[0], row[1:]) for row in txn]
+
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_user_signature_changes_for_remotes",
- None,
- sql,
- from_key,
- to_key,
- limit,
+ _get_all_user_signature_changes_for_remotes_txn,
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index cfd24d2f06..230fb5cd7f 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import OrderedDict, namedtuple
@@ -28,12 +27,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RelationTypes,
-)
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
@@ -48,8 +42,8 @@ from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
- from synapse.storage.data_stores.main import DataStore
from synapse.server import HomeServer
+ from synapse.storage.data_stores.main import DataStore
logger = logging.getLogger(__name__)
@@ -820,7 +814,6 @@ class PersistEventsStore:
"event_reference_hashes",
"event_search",
"event_to_state_groups",
- "local_invites",
"state_events",
"rejections",
"redactions",
@@ -1197,65 +1190,27 @@ class PersistEventsStore:
(event.state_key,),
)
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self.db.simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- # We also update the `local_current_membership` table with
- # latest invite info. This will usually get updated by the
- # `current_state_events` handling, unless its an outlier.
- if event.internal_metadata.is_outlier():
- # This should only happen for out of band memberships, so
- # we add a paranoia check.
- assert event.internal_metadata.is_out_of_band_membership()
-
- self.db.simple_upsert_txn(
- txn,
- table="local_current_membership",
- keyvalues={
- "room_id": event.room_id,
- "user_id": event.state_key,
- },
- values={
- "event_id": event.event_id,
- "membership": event.membership,
- },
- )
+ # We update the local_current_membership table only if the event is
+ # "current", i.e., its something that has just happened.
+ #
+ # This will usually get updated by the `current_state_events` handling,
+ # unless it's an outlier, and an outlier is only "current" if it's an "out of
+ # band membership", like a remote invite or a rejection of a remote invite.
+ if (
+ self.is_mine_id(event.state_key)
+ and not backfilled
+ and event.internal_metadata.is_outlier()
+ and event.internal_metadata.is_out_of_band_membership()
+ ):
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": event.room_id, "user_id": event.state_key},
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
@@ -1586,31 +1541,3 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
-
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- # We also clear this entry from `local_current_membership`.
- # Ideally we'd point to a leave event, but we don't have one, so
- # nevermind.
- self.db.simple_delete_txn(
- txn,
- table="local_current_membership",
- keyvalues={"room_id": room_id, "user_id": user_id},
- )
-
- with self._stream_id_gen.get_next() as stream_ordering:
- await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
- return stream_ordering
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index a48c7a96ca..01cad7d4fa 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -80,10 +82,7 @@ class EventsWorkerStore(SQLBaseStore):
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
+ db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@@ -113,9 +112,9 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
+ if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(token)
- elif stream_name == "backfill":
+ elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(-token)
super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index fb1361f1c1..4fb9f9850c 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from canonicaljson import json
from twisted.internet import defer
@@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
- def get_all_groups_changes(self, from_token, to_token, limit):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(
- from_token
- )
+ async def get_all_groups_changes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for groups replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ last_id = int(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+
if not has_changed:
- return defer.succeed([])
+ return [], current_id, False
def _get_all_groups_changes_txn(txn):
sql = """
@@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, limit))
- return [
- (stream_id, group_id, user_id, gtype, json.loads(content_json))
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [
+ (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
- return self.db.runInteraction(
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index a93e1ef198..6546569139 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -361,7 +361,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
- "local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 547b9d69cb..5461016240 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator
+from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json, json
@@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
- def get_all_updated_pushers(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed(([], []))
-
- def get_all_updated_pushers_txn(txn):
- sql = (
- "SELECT id, user_name, access_token, profile_tag, kind,"
- " app_id, app_display_name, device_display_name, pushkey, ts,"
- " lang, data"
- " FROM pushers"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- updated = txn.fetchall()
-
- sql = (
- "SELECT stream_id, user_id, app_id, pushkey"
- " FROM deleted_pushers"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- deleted = txn.fetchall()
+ async def get_all_updated_pushers_rows(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for pushers replication stream.
- return updated, deleted
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
- return self.db.runInteraction(
- "get_all_updated_pushers", get_all_updated_pushers_txn
- )
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
- def get_all_updated_pushers_rows(self, last_id, current_id, limit):
- """Get all the pushers that have changed between the given tokens.
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
- Returns:
- Deferred(list(tuple)): each tuple consists of:
- stream_id (str)
- user_id (str)
- app_id (str)
- pushkey (str)
- was_deleted (bool): whether the pusher was added/updated (False)
- or deleted (True)
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_updated_pushers_rows_txn(txn):
- sql = (
- "SELECT id, user_name, app_id, pushkey"
- " FROM pushers"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC LIMIT ?"
- )
+ sql = """
+ SELECT id, user_name, app_id, pushkey
+ FROM pushers
+ WHERE ? < id AND id <= ?
+ ORDER BY id ASC LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
- results = [list(row) + [False] for row in txn]
-
- sql = (
- "SELECT stream_id, user_id, app_id, pushkey"
- " FROM deleted_pushers"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
+ updates = [
+ (stream_id, (user_name, app_id, pushkey, False))
+ for stream_id, user_name, app_id, pushkey in txn
+ ]
+
+ sql = """
+ SELECT stream_id, user_id, app_id, pushkey
+ FROM deleted_pushers
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
+ updates.extend(
+ (stream_id, (user_name, app_id, pushkey, True))
+ for stream_id, user_name, app_id, pushkey in txn
+ )
+
+ updates.sort() # Sort so that they're ordered by stream id
- results.extend(list(row) + [True] for row in txn)
- results.sort() # Sort so that they're ordered by stream id
+ limited = False
+ upper_bound = current_id
+ if len(updates) >= limit:
+ limited = True
+ upper_bound = updates[-1][0]
- return results
+ return updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
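The removed get_all_updated_pushers returned separate updated and deleted lists; the surviving method now emits one stream-ordered list of (stream_id, (user_name, app_id, pushkey, was_deleted)) pairs. A sketch of the merged row shape, with hypothetical data:

```python
updates = [
    (101, ("@alice:example.org", "m.email", "pushkey1", False)),  # upserted
    (103, ("@bob:example.org", "m.email", "pushkey2", True)),     # deleted
]
updates.sort()  # keep the combined list ordered by stream ID
```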
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 13e366536a..c473cf158f 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ async def get_all_new_public_rooms(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for public rooms replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+ if last_id == current_id:
+ return [], current_id, False
+
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
@@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- if prev_id == current_id:
- return defer.succeed([])
+ return updates, upto_token, limited
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py
index 4b2ffd35fd..ee675e71ff 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/fts.py
+++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.prepare_database import get_statements
@@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py
index 414f9f5aa0..b7972cfa8e 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
index 7d8ca5f93f..63b757ade6 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
@@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
index bff1256a7b..a3e81eeac7 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index f8c776be3f..290317fd94 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
return deferred
- @defer.inlineCallbacks
- def get_all_updated_tags(self, last_id, current_id, limit):
- """Get all the client tags that have changed on the server
+ async def get_all_updated_tags(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for tags replication stream.
+
Args:
- last_id(int): The position to fetch from.
- current_id(int): The position to fetch up to.
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
Returns:
- A deferred list of tuples of stream_id int, user_id string,
- room_id string, tag string and content string.
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
+
if last_id == current_id:
- return []
+ return [], current_id, False
def get_all_updated_tags_txn(txn):
sql = (
@@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = yield self.db.runInteraction(
+ tag_ids = await self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, user_id, room_id, tag_json))
+ results.append((stream_id, (user_id, room_id, tag_json)))
return results
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = yield self.db.runInteraction(
+ tags = await self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
- return results
+ limited = False
+ upto_token = current_id
+ if len(results) >= limit:
+ upto_token = results[-1][0]
+ limited = True
+
+ return results, upto_token, limited
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index ec2f38c373..4c044b1a15 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union
import attr
-import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.types import JsonDict
+from synapse.util import stringutils as stringutils
@attr.s
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6c7d08a6f2..a31588080d 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine):
errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
if ctype != "C":
- errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+ errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
if errors:
raise IncorrectDatabaseSetup(
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index ec894a91cb..fa46041676 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -783,9 +783,3 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
- return await self.persist_events_store.locally_reject_invite(user_id, room_id)
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index daff81c5ee..2d2b560e74 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,12 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
-
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index cd56cd91ed..ca7c16ff65 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -68,13 +68,13 @@ class PaginationConfig(object):
elif from_tok:
from_tok = StreamToken.from_string(from_tok)
except Exception:
- raise SynapseError(400, "'from' paramater is invalid")
+ raise SynapseError(400, "'from' parameter is invalid")
try:
if to_tok:
to_tok = StreamToken.from_string(to_tok)
except Exception:
- raise SynapseError(400, "'to' paramater is invalid")
+ raise SynapseError(400, "'to' parameter is invalid")
limit = parse_integer(request, "limit", default=default_limit)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index fcd2aaa9c9..5d3eddcfdc 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -68,7 +68,7 @@ class EventSources(object):
The returned token does not have the current values for fields other
than `room`, since they are not used during pagination.
- Retuns:
+ Returns:
Deferred[StreamToken]
"""
token = StreamToken(
diff --git a/synapse/types.py b/synapse/types.py
index acf60baddc..238b938064 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Sized, Iterable, Container
+ from typing import Container, Iterable, Sized
T_co = TypeVar("T_co", covariant=True)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 60f0de70f7..c63256d3bd 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -55,7 +55,7 @@ class Clock(object):
return self._reactor.seconds()
def time_msec(self):
- """Returns the current system time in miliseconds since epoch."""
+ """Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
def looping_call(self, f, msec, *args, **kwargs):
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 65abf0846e..f562770922 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -352,7 +352,7 @@ class ReadWriteLock(object):
# resolved when they release the lock).
#
# Read: We know its safe to acquire a read lock when the latest writer has
- # been resolved. The new reader is appeneded to the list of latest readers.
+ # been resolved. The new reader is appended to the list of latest readers.
#
# Write: We know its safe to acquire the write lock when both the latest
# writers and readers have been resolved. The new writer replaces the latest
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 64f35fc288..9b09c08b89 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -516,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
"""
Args:
orig (function)
- cached_method_name (str): The name of the chached method.
+ cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int): number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 45af8d3eeb..da20523b70 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -39,7 +39,7 @@ class Distributor(object):
Signals are named simply by strings.
TODO(paul): It would be nice to give signals stronger object identities,
- so we can attach metadata, docstrings, detect typoes, etc... But this
+ so we can attach metadata, docstrings, detect typos, etc... But this
model will do for today.
"""
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 2605f3c65b..54c046b6e1 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -192,7 +192,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
result = yield d
except Exception:
# this will fish an earlier Failure out of the stack where possible, and
- # thus is preferable to passing in an exeception to the Failure
+ # thus is preferable to passing in an exception to the Failure
# constructor, since it results in less stack-mangling.
result = Failure()
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index af69587196..8794317caa 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -22,7 +22,7 @@ from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
-# the intial backoff, after the first transaction fails
+# the initial backoff, after the first transaction fails
MIN_RETRY_INTERVAL = 10 * 60 * 1000
# how much we multiply the backoff by after each subsequent fail
@@ -174,7 +174,7 @@ class RetryDestinationLimiter(object):
# has been decommissioned.
# If we get a 401, then we should probably back off since they
# won't accept our requests for at least a while.
- # 429 is us being aggresively rate limited, so lets rate limit
+ # 429 is us being aggressively rate limited, so let's rate limit
# ourselves.
if exc_val.code == 404 and self.backoff_on_404:
valid_err_code = False
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..43c2e0ac23 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -48,3 +48,26 @@ def check_3pid_allowed(hs, medium, address):
return True
return False
+
+
+def canonicalise_email(address: str) -> str:
+ """'Canonicalise' email address
+ Case-folds the local part of the address and lowercases the domain part.
+ See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265
+
+ Args:
+ address: email address to be canonicalised
+ Returns:
+ The canonical form of the email address
+ Raises:
+ ValueError if the address could not be parsed.
+ """
+
+ address = address.strip()
+
+ parts = address.split("@")
+ if len(parts) != 2:
+ logger.debug("Couldn't parse email address %s", address)
+ raise ValueError("Unable to parse email address")
+
+ return parts[0].casefold() + "@" + parts[1].lower()
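For illustration, what the new helper yields for a few inputs (a sketch; the expected values follow directly from the strip/casefold/lower steps above):

```python
from synapse.util.threepids import canonicalise_email

# Whitespace is stripped; the local part is case-folded and the domain
# lower-cased, per MSC2265.
assert canonicalise_email("  Alice@EXAMPLE.org ") == "alice@example.org"

# casefold() is stronger than lower(): the German sharp s folds to "ss".
assert canonicalise_email("Straße@Example.COM") == "strasse@example.com"

# Anything without exactly one "@" raises ValueError.
try:
    canonicalise_email("not-an-address")
except ValueError:
    pass
```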
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 3dfd4af26c..0f042c5696 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -319,7 +319,7 @@ def filter_events_for_server(
return True
# Lets check to see if all the events have a history visibility
- # of "shared" or "world_readable". If thats the case then we don't
+ # of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
@@ -335,7 +335,7 @@ def filter_events_for_server(
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
- # is "shared" visiblity.
+ # is "shared" visibility.
if not visibility_ids:
all_open = True
else:
|