diff --git a/changelog.d/9498.bugfix b/changelog.d/9498.bugfix
new file mode 100644
index 0000000000..dce0ad0920
--- /dev/null
+++ b/changelog.d/9498.bugfix
@@ -0,0 +1 @@
+Properly purge the event chain cover index when purging history.
diff --git a/changelog.d/9518.misc b/changelog.d/9518.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9518.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9521.misc b/changelog.d/9521.misc
new file mode 100644
index 0000000000..1424d9c188
--- /dev/null
+++ b/changelog.d/9521.misc
@@ -0,0 +1 @@
+Add type hints to user admin API.
\ No newline at end of file
diff --git a/changelog.d/9529.misc b/changelog.d/9529.misc
new file mode 100644
index 0000000000..b9021a26b4
--- /dev/null
+++ b/changelog.d/9529.misc
@@ -0,0 +1 @@
+Bump the versions of mypy and mypy-zope used for static type checking.
diff --git a/changelog.d/9536.bugfix b/changelog.d/9536.bugfix
new file mode 100644
index 0000000000..2ab4f315c1
--- /dev/null
+++ b/changelog.d/9536.bugfix
@@ -0,0 +1 @@
+Fix deleting pushers when using sharded pushers.
diff --git a/changelog.d/9537.bugfix b/changelog.d/9537.bugfix
new file mode 100644
index 0000000000..033ab1c939
--- /dev/null
+++ b/changelog.d/9537.bugfix
@@ -0,0 +1 @@
+Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events.
diff --git a/changelog.d/9539.feature b/changelog.d/9539.feature
new file mode 100644
index 0000000000..06cfd5d199
--- /dev/null
+++ b/changelog.d/9539.feature
@@ -0,0 +1 @@
+Add support for `X-Forwarded-Proto` header when using a reverse proxy.
diff --git a/setup.py b/setup.py
index 08ba4eb764..bbd9e7862a 100755
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index d9423349e1..52314db95c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -23,6 +23,7 @@ from typing_extensions import ContextManager
from twisted.internet import address
from twisted.web.resource import IResource
+from twisted.web.server import Request
import synapse
import synapse.events
@@ -190,7 +191,7 @@ class KeyUploadServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
self.main_uri = hs.config.worker_main_http_uri
- async def on_POST(self, request, device_id):
+ async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -223,10 +224,12 @@ class KeyUploadServlet(RestServlet):
header: request.requestHeaders.getRawHeaders(header, [])
for header in (b"Authorization", b"User-Agent")
}
- # Add the previous hop the the X-Forwarded-For header.
+ # Add the previous hop to the X-Forwarded-For header.
x_forwarded_for = request.requestHeaders.getRawHeaders(
b"X-Forwarded-For", []
)
+ # we use request.client here, since we want the previous hop, not the
+ # original client (as returned by request.getClientAddress()).
if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
previous_host = request.client.host.encode("ascii")
# If the header exists, add to the comma-separated list of the first
@@ -239,6 +242,14 @@ class KeyUploadServlet(RestServlet):
x_forwarded_for = [previous_host]
headers[b"X-Forwarded-For"] = x_forwarded_for
+ # Replicate the original X-Forwarded-Proto header. Note that
+ # XForwardedForRequest overrides isSecure() to give us the original protocol
+ # used by the client, as opposed to the protocol used by our upstream proxy
+ # - which is what we want here.
+ headers[b"X-Forwarded-Proto"] = [
+ b"https" if request.isSecure() else b"http"
+ ]
+
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2e83fa6773..b07aa59c08 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
- ) -> defer.Deferred:
+ ) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""
Args:
method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
# We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided.
if headers is None:
- headers = Headers()
+ request_headers = Headers()
else:
- headers = headers.copy()
+ request_headers = headers.copy()
- if not headers.hasHeader(b"host"):
- headers.addRawHeader(b"host", parsed_uri.netloc)
- if not headers.hasHeader(b"user-agent"):
- headers.addRawHeader(b"user-agent", self.user_agent)
+ if not request_headers.hasHeader(b"host"):
+ request_headers.addRawHeader(b"host", parsed_uri.netloc)
+ if not request_headers.hasHeader(b"user-agent"):
+ request_headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
- self._agent.request(method, uri, headers, bodyProducer)
+ self._agent.request(method, uri, request_headers, bodyProducer)
)
return res
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cde42e9f5e..0f107714ea 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
- c_type = headers.getRawHeaders(b"Content-Type")
- if c_type is None:
+ content_type_headers = headers.getRawHeaders(b"Content-Type")
+ if content_type_headers is None:
raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"),
can_retry=False,
)
- c_type = c_type[0].decode("ascii") # only the first header
+ c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845db9b78d..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,6 +21,7 @@ import logging
import types
import urllib
from http import HTTPStatus
+from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
@@ -30,6 +31,7 @@ from typing import (
Iterable,
Iterator,
List,
+ Optional,
Pattern,
Tuple,
Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
- error_code = f.value.code
- error_dict = f.value.error_dict()
+ # mypy doesn't understand that f.check asserts the type.
+ exc = f.value # type: SynapseError # type: ignore
+ error_code = exc.code
+ error_dict = exc.error_dict()
- logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+ logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
# Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
- cme = f.value
+ # mypy doesn't understand that f.check asserts the type.
+ cme = f.value # type: CodeMessageException # type: ignore
code = cme.code
msg = cme.msg
@@ -142,7 +147,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
- if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
+ # At this point the path must be bytes.
+ request_path_bytes = request.path # type: bytes # type: ignore
+ request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
- request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
request: Request,
iterator: Iterator[bytes],
):
- self._request = request
+ self._request = request # type: Optional[Request]
self._iterator = iterator
self._paused = False
@@ -563,7 +570,7 @@ class _ByteProducer:
"""
Send a list of bytes as a chunk of a response.
"""
- if not data:
+ if not data or not self._request:
return
self._request.write(b"".join(data))
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 30153237e3..47754aff43 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
import attr
from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
- self.site = channel.site
+ self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)
- def get_redacted_uri(self):
- uri = self.uri
+ def get_redacted_uri(self) -> str:
+ """Gets the redacted URI associated with the request (or placeholder if the URI
+ has not yet been received).
+
+ Note: This is necessary as the placeholder value in twisted is str
+ rather than bytes, so we need to sanitise `self.uri`.
+
+ Returns:
+ The redacted URI as a string.
+ """
+ uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes):
- uri = self.uri.decode("ascii", errors="replace")
+ uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)
- def get_method(self):
- """Gets the method associated with the request (or placeholder if not
- method has yet been received).
+ def get_method(self) -> str:
+ """Gets the method associated with the request (or placeholder if method
+ has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`.
Returns:
- str
+ The request method as a string.
"""
- method = self.method
+ method = self.method # type: Union[bytes, str]
if isinstance(method, bytes):
- method = self.method.decode("ascii")
+ return self.method.decode("ascii")
return method
def render(self, resrc):
@@ -432,7 +441,9 @@ class SynapseSite(Site):
assert config.http_options is not None
proxied = config.http_options.x_forwarded
- self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+ self.requestFactory = (
+ XForwardedForRequest if proxied else SynapseRequest
+ ) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index f8e9112b56..174ca7be5a 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.python.failure import Failure
@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+ endpoint = TCP4ClientEndpoint(
+ _reactor, self.host, self.port
+ ) # type: IStreamClientEndpoint
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index a8cb49d5b4..3b499efc07 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
REGISTRY.register(ReactorLastSeenMetric())
-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):
try:
# Ensure the reactor has all the attributes we expect
- reactor.runUntilCurrent
- reactor._newTimedCalls
- reactor.threadCallQueue
+ reactor.seconds # type: ignore
+ reactor.runUntilCurrent # type: ignore
+ reactor._newTimedCalls # type: ignore
+ reactor.threadCallQueue # type: ignore
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
- reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+ reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2e3b311c4a..db2d400b7e 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
from twisted.internet import defer
@@ -307,7 +307,7 @@ class ModuleApi:
@defer.inlineCallbacks
def get_state_events_in_room(
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
- ) -> defer.Deferred:
+ ) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""Gets current state events for the given room.
(This is exposed for compatibility with the old SpamCheckerApi. We should
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f4d7e199e9..eb6de8ba72 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
from synapse.events import EventBase
@@ -71,7 +72,7 @@ class HttpPusher(Pusher):
self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusher_config.failing_since
- self.timed_call = None
+ self.timed_call = None # type: Optional[IDelayedCall]
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2618eb1e53..3455839d67 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -108,9 +108,7 @@ class ReplicationDataHandler:
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters = (
- {}
- ) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+ self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 38809b5b7c..f45e7a8c89 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -502,7 +502,7 @@ class AccountDataStream(Stream):
"""Global or per room account data was changed"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream",
+ "AccountDataStreamRow",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 9c701c7348..267a993430 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request)
@@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
)
user = await self.admin_handler.get_user(target_user)
+ assert user is not None
+
return 200, user
else: # create user
@@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
target_user, requester, body["avatar_url"], True
)
- ret = await self.admin_handler.get_user(target_user)
+ user = await self.admin_handler.get_user(target_user)
+ assert user is not None
- return 201, ret
+ return 201, user
class UserRegisterServlet(RestServlet):
@@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
- self.nonces = {}
+ self.nonces = {} # type: Dict[str, int]
self.hs = hs
def _clear_old_nonces(self):
@@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Generate a new nonce.
"""
@@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
@@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True)
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
@@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
- async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
@@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
- async def on_POST(self, request, target_user_id):
+ async def on_POST(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
@@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, target_user_id):
+ async def on_GET(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, Optional[List[JsonDict]]]:
"""Get request to search user table for specific users according to
search term.
This needs user to have a administrator access in Synapse.
@@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
return 200, {"admin": is_admin}
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
diff --git a/synapse/server.py b/synapse/server.py
index 1d4370e0ba..afd7cd72e7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -38,6 +38,7 @@ from typing import (
import twisted.internet.base
import twisted.internet.tcp
+from twisted.internet import defer
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
@@ -403,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return RoomShutdownHandler(self)
@cache_in_self
- def get_sendmail(self) -> sendmail:
+ def get_sendmail(self) -> Callable[..., defer.Deferred]:
return sendmail
@cache_in_self
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70b49854cf..1d44c3aa2c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- async def get_users(self) -> List[Dict[str, Any]]:
+ async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.
Returns:
@@ -292,7 +292,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
- async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
+ async def search_users(self, term: str) -> Optional[List[JsonDict]]:
"""Function to search users list for one or more users with
the matched term.
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c1626ccf28..cb6b1f8a0c 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
if not has_event_auth:
- for auth_id in event.auth_event_ids():
+ # Old, dodgy, events may have duplicate auth events, which we
+ # need to deduplicate as we have a unique constraint.
+ for auth_id in set(event.auth_event_ids()):
auth_events.append(
{
"room_id": event.room_id,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 274f8de595..4f3d192562 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: int,
limit: int,
user_id: str,
- order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
+ order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc9f20b1..0836e4af49 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> Set[int]:
- """Deletes room history before a certain point
+ """Deletes room history before a certain point.
+
+ Note that only a single purge can occur at once, this is guaranteed via
+ a higher level (in the PaginationHandler).
Args:
room_id:
@@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
delete_local_events,
)
- def _purge_history_txn(self, txn, room_id, token, delete_local_events):
+ def _purge_history_txn(
+ self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+ ) -> Set[int]:
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremeties)
+ # having any backwards extremities)
raise SynapseError(
400, "topological_ordering is greater than forward extremeties"
)
@@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
logger.info("[purge] Finding new backward extremities")
- # We calculate the new entries for the backward extremeties by finding
+ # We calculate the new entries for the backward extremities by finding
# events to be purged that are pointed to by events we're not going to
# purge.
txn.execute(
@@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"purge_room", self._purge_room_txn, room_id
)
- def _purge_room_txn(self, txn, room_id):
+ def _purge_room_txn(self, txn, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
@@ -310,6 +315,31 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
state_groups = [row[0] for row in txn]
+ # Get all the auth chains that are referenced by events that are to be
+ # deleted.
+ txn.execute(
+ """
+ SELECT chain_id, sequence_number FROM events
+ LEFT JOIN event_auth_chains USING (event_id)
+ WHERE room_id = ?
+ """,
+ (room_id,),
+ )
+ referenced_chain_id_tuples = list(txn)
+
+ logger.info("[purge] removing events from event_auth_chain_links")
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chain_links WHERE
+ (origin_chain_id = ? AND origin_sequence_number = ?) OR
+ (target_chain_id = ? AND target_sequence_number = ?)
+ """,
+ (
+ (chain_id, seq_num, chain_id, seq_num)
+ for (chain_id, seq_num) in referenced_chain_id_tuples
+ ),
+ )
+
# Now we delete tables which lack an index on room_id but have one on event_id
for table in (
"event_auth",
@@ -319,6 +349,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_reference_hashes",
"event_relations",
"event_to_state_groups",
+ "event_auth_chains",
+ "event_auth_chain_to_calculate",
"redactions",
"rejections",
"state_events",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 6b608ebc9b..85f1ebac98 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -44,6 +44,11 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deactivated_pushers,
)
+ self.db_pool.updates.register_background_update_handler(
+ "remove_stale_pushers",
+ self._remove_stale_pushers,
+ )
+
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -337,6 +342,53 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
+ async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
+ """A background update that deletes all pushers for logged out devices.
+
+ Note that we don't proacively tell the pusherpool that we've deleted
+ these (just because its a bit off a faff to do from here), but they will
+ get cleaned up at the next restart
+ """
+
+ last_pusher = progress.get("last_pusher", 0)
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT p.id, access_token FROM pushers AS p
+ LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
+ WHERE p.id > ?
+ ORDER BY p.id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_pusher, batch_size))
+ pushers = [(row[0], row[1]) for row in txn]
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="pushers",
+ column="id",
+ iterable=(pusher_id for pusher_id, token in pushers if token is None),
+ keyvalues={},
+ )
+
+ if pushers:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
+ )
+
+ return len(pushers)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_stale_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update("remove_stale_pushers")
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
index 2442eea6bc..85196db288 100644
--- a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -16,4 +16,5 @@
-- Delete all pushers associated with deleted devices. This is to clear up after
-- a bug where they weren't correctly deleted when using workers.
-DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5908, 'remove_stale_pushers', '{}');
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 3c4908865f..4dcd848c59 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -73,9 +73,6 @@ class PurgeEventsStorage:
Returns:
The set of state groups that can be deleted.
"""
- # Graph of state group -> previous group
- graph = {}
-
# Set of events that we have found to be referenced by events
referenced_groups = set()
@@ -111,8 +108,6 @@ class PurgeEventsStorage:
next_to_search |= prevs
state_groups_seen |= prevs
- graph.update(edges)
-
to_delete = state_groups_seen - referenced_groups
return to_delete
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index f152f63321..d2ff4da6b9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
+ "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 744d8d0941..20af3285bd 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
- cas_uri = channel.headers.getRawHeaders("Location")[0]
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ cas_uri = location_headers[0]
cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
# it should redirect us to the login page of the cas server
@@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
- saml_uri = channel.headers.getRawHeaders("Location")[0]
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ saml_uri = location_headers[0]
saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
# it should redirect us to the login page of the SAML server
@@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
- oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
# ... and should have set a cookie including the redirect url
- cookies = dict(
- h.split(";")[0].split("=", maxsplit=1)
- for h in channel.headers.getRawHeaders("Set-Cookie")
- )
+ cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+ assert cookie_headers
+ cookies = {} # type: Dict[str, str]
+ for h in cookie_headers:
+ key, value = h.split(";")[0].split("=", maxsplit=1)
+ cookies[key] = value
oidc_session_cookie = cookies["oidc_session"]
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
@@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
- self.assertTrue(
- channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
- )
+ content_type_headers = channel.headers.getRawHeaders("Content-Type")
+ assert content_type_headers
+ self.assertTrue(content_type_headers[-1].startswith("text/html"))
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
@@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result)
- picker_url = channel.headers.getRawHeaders("Location")[0]
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ picker_url = location_headers[0]
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
@@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
# send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request(
@@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
|