diff --git a/CHANGES.md b/CHANGES.md
index d859baa9ff..92e29983b9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,17 @@
+For the next release
+====================
+
+Removal warning
+---------------
+
+Some older clients used a
+[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
+(`:`) in the `client_secret` parameter of various endpoints. The incorrect
+behaviour was allowed for backwards compatibility, but is now being removed
+from Synapse as most users have updated their client. Further context can be
+found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
+
+
Synapse 1.19.1 (2020-08-27)
===========================
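
For illustration, a minimal sketch of the strict check implied by the removal warning above. The names are hypothetical; the allowed character set (`[0-9a-zA-Z.=_-]`) comes from the linked spec section, and `:` is the legacy character that will no longer be tolerated.

```python
import re

# Spec-allowed characters for client_secret; the legacy behaviour being
# removed additionally tolerated ":".
CLIENT_SECRET_PATTERN = re.compile(r"^[0-9a-zA-Z.=_-]+$")

def is_valid_client_secret(client_secret: str) -> bool:
    """Return True only if the client_secret uses spec-allowed characters."""
    return CLIENT_SECRET_PATTERN.match(client_secret) is not None
```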
diff --git a/changelog.d/7757.misc b/changelog.d/7757.misc
new file mode 100644
index 0000000000..091f40382e
--- /dev/null
+++ b/changelog.d/7757.misc
@@ -0,0 +1 @@
+Reduce run times of some unit tests by advancing the reactor fewer times.
\ No newline at end of file
diff --git a/changelog.d/8130.misc b/changelog.d/8130.misc
new file mode 100644
index 0000000000..7944c09ade
--- /dev/null
+++ b/changelog.d/8130.misc
@@ -0,0 +1 @@
+Update the test federation client to handle streaming responses.
diff --git a/changelog.d/8144.docker b/changelog.d/8144.docker
new file mode 100644
index 0000000000..9bb5881fa8
--- /dev/null
+++ b/changelog.d/8144.docker
@@ -0,0 +1 @@
+Fix builds of the Docker image on non-x86 platforms.
diff --git a/changelog.d/8157.feature b/changelog.d/8157.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8157.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc
new file mode 100644
index 0000000000..e26764dea1
--- /dev/null
+++ b/changelog.d/8162.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8171.misc b/changelog.d/8171.misc
new file mode 100644
index 0000000000..cafbf23d83
--- /dev/null
+++ b/changelog.d/8171.misc
@@ -0,0 +1 @@
+Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`.
diff --git a/changelog.d/8174.misc b/changelog.d/8174.misc
new file mode 100644
index 0000000000..a39e9eab46
--- /dev/null
+++ b/changelog.d/8174.misc
@@ -0,0 +1 @@
+Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`.
\ No newline at end of file
diff --git a/changelog.d/8175.misc b/changelog.d/8175.misc
new file mode 100644
index 0000000000..28af294dcf
--- /dev/null
+++ b/changelog.d/8175.misc
@@ -0,0 +1 @@
+Standardize the mypy configuration.
diff --git a/changelog.d/8176.feature b/changelog.d/8176.feature
new file mode 100644
index 0000000000..813e6d0903
--- /dev/null
+++ b/changelog.d/8176.feature
@@ -0,0 +1 @@
+Add support for shadow-banning users (ignoring any message send requests).
diff --git a/changelog.d/8181.misc b/changelog.d/8181.misc
new file mode 100644
index 0000000000..a39e9eab46
--- /dev/null
+++ b/changelog.d/8181.misc
@@ -0,0 +1 @@
+Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`.
\ No newline at end of file
diff --git a/debian/changelog b/debian/changelog
index 6676706dea..bde3b636ee 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.19.0ubuntu1) UNRELEASED; urgency=medium
+
+ * Use Type=notify in systemd service
+
+ -- Dexter Chua <dec41@srcf.net> Wed, 26 Aug 2020 12:41:36 +0000
+
matrix-synapse-py3 (1.19.1) stable; urgency=medium
* New synapse release 1.19.1.
diff --git a/debian/matrix-synapse.service b/debian/matrix-synapse.service
index b0a8d72e6d..553babf549 100644
--- a/debian/matrix-synapse.service
+++ b/debian/matrix-synapse.service
@@ -2,7 +2,7 @@
Description=Synapse Matrix homeserver
[Service]
-Type=simple
+Type=notify
User=matrix-synapse
WorkingDirectory=/var/lib/matrix-synapse
EnvironmentFile=/etc/default/matrix-synapse
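
`Type=notify` makes systemd wait for an explicit readiness notification instead of treating the unit as started the moment the process launches. A minimal sketch of that handshake using the third-party `sdnotify` package (used here purely for illustration; this diff does not show Synapse's own startup code):

```python
import sdnotify

# With Type=notify, the unit only becomes "active" once READY=1 arrives on
# the notification socket systemd passes via $NOTIFY_SOCKET.
notifier = sdnotify.SystemdNotifier()
notifier.notify("READY=1")

# Optional: report a human-readable status afterwards.
notifier.notify("STATUS=Serving requests")
```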
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 432d56a8ee..27512f8600 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -19,11 +19,16 @@ ARG PYTHON_VERSION=3.7
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
-
-
RUN apt-get update && apt-get install -y \
build-essential \
+ libffi-dev \
+ libjpeg-dev \
libpq-dev \
+ libssl-dev \
+ libwebp-dev \
+ libxml++2.6-dev \
+ libxslt1-dev \
+ zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Build dependencies that are not available as wheels, to speed up rebuilds
@@ -56,9 +61,11 @@ FROM docker.io/python:${PYTHON_VERSION}-slim
RUN apt-get update && apt-get install -y \
curl \
+ gosu \
+ libjpeg62-turbo \
libpq5 \
+ libwebp6 \
xmlsec1 \
- gosu \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
diff --git a/mypy.ini b/mypy.ini
index c69cb5dc40..4213e31b03 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -6,6 +6,55 @@ check_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
+files =
+ synapse/api,
+ synapse/appservice,
+ synapse/config,
+ synapse/event_auth.py,
+ synapse/events/builder.py,
+ synapse/events/spamcheck.py,
+ synapse/federation,
+ synapse/handlers/auth.py,
+ synapse/handlers/cas_handler.py,
+ synapse/handlers/directory.py,
+ synapse/handlers/federation.py,
+ synapse/handlers/identity.py,
+ synapse/handlers/message.py,
+ synapse/handlers/oidc_handler.py,
+ synapse/handlers/presence.py,
+ synapse/handlers/room.py,
+ synapse/handlers/room_member.py,
+ synapse/handlers/room_member_worker.py,
+ synapse/handlers/saml_handler.py,
+ synapse/handlers/sync.py,
+ synapse/handlers/ui_auth,
+ synapse/http/server.py,
+ synapse/http/site.py,
+ synapse/logging/,
+ synapse/metrics,
+ synapse/module_api,
+ synapse/notifier.py,
+ synapse/push/pusherpool.py,
+ synapse/push/push_rule_evaluator.py,
+ synapse/replication,
+ synapse/rest,
+ synapse/server.py,
+ synapse/server_notices,
+ synapse/spam_checker_api,
+ synapse/state,
+ synapse/storage/databases/main/ui_auth.py,
+ synapse/storage/database.py,
+ synapse/storage/engines,
+ synapse/storage/state.py,
+ synapse/storage/util,
+ synapse/streams,
+ synapse/types.py,
+ synapse/util/caches/stream_change_cache.py,
+ synapse/util/metrics.py,
+ tests/replication,
+ tests/test_utils,
+ tests/rest/client/v2_alpha/test_auth.py,
+ tests/util/test_stream_change_cache.py
[mypy-pymacaroons.*]
ignore_missing_imports = True
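
With the `files` list above, running `mypy` with no path arguments type-checks exactly those modules. A small sketch of driving the same check from Python via mypy's programmatic API (assuming the `mypy` package is installed):

```python
from mypy import api

# No explicit paths: mypy picks up the `files` setting from mypy.ini and
# checks only the listed modules.
stdout, stderr, exit_status = api.run([])
print(stdout)
if exit_status != 0:
    print("mypy reported type errors")
```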
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 531010185d..ad12523c4d 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -21,10 +21,12 @@ import argparse
import base64
import json
import sys
+from typing import Any, Optional
from urllib import parse as urlparse
import nacl.signing
import requests
+import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
@@ -69,7 +71,9 @@ def encode_canonical_json(value):
).encode("UTF-8")
-def sign_json(json_object, signing_key, signing_name):
+def sign_json(
+ json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str
+) -> Any:
signatures = json_object.pop("signatures", {})
unsigned = json_object.pop("unsigned", None)
@@ -122,7 +126,14 @@ def read_signing_keys(stream):
return keys
-def request_json(method, origin_name, origin_key, destination, path, content):
+def request(
+ method: Optional[str],
+ origin_name: str,
+ origin_key: signedjson.types.SigningKey,
+ destination: str,
+ path: str,
+ content: Optional[str],
+) -> requests.Response:
if method is None:
if content is None:
method = "GET"
@@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content):
if method == "POST":
headers["Content-Type"] = "application/json"
- result = s.request(
- method=method, url=dest, headers=headers, verify=False, data=content
+ return s.request(
+ method=method,
+ url=dest,
+ headers=headers,
+ verify=False,
+ data=content,
+ stream=True,
)
- sys.stderr.write("Status Code: %d\n" % (result.status_code,))
- return result.json()
def main():
@@ -222,7 +236,7 @@ def main():
with open(args.signing_key_path) as f:
key = read_signing_keys(f)[0]
- result = request_json(
+ result = request(
args.method,
args.server_name,
key,
@@ -231,7 +245,12 @@ def main():
content=args.body,
)
- json.dump(result, sys.stdout)
+ sys.stderr.write("Status Code: %d\n" % (result.status_code,))
+
+ for chunk in result.iter_content():
+ # we write raw utf8 to stdout.
+ sys.stdout.buffer.write(chunk)
+
print("")
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 755a52a50d..14b03490fa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -96,12 +96,7 @@ class MessageHandler(object):
)
async def get_room_data(
- self,
- user_id: str,
- room_id: str,
- event_type: str,
- state_key: str,
- is_guest: bool,
+ self, user_id: str, room_id: str, event_type: str, state_key: str,
) -> dict:
""" Get data from a room.
@@ -110,11 +105,10 @@ class MessageHandler(object):
room_id
event_type
state_key
- is_guest
Returns:
The path data content.
Raises:
- SynapseError if something went wrong.
+ SynapseError or AuthError if the user is not in the room
"""
(
membership,
@@ -131,6 +125,16 @@ class MessageHandler(object):
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
+ else:
+ # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should
+ # only ever return a Membership.JOIN/LEAVE object
+ #
+ # Safeguard in case it returned something else
+ logger.error(
+ "Attempted to retrieve data from a room for a user that has never been in it. "
+ "This should not have happened."
+ )
+ raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
return data
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4f3198896e..adb9dc7c42 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
+import random
from typing import List
from signedjson.sign import sign_json
-from twisted.internet import defer, reactor
+from twisted.internet import reactor
from synapse.api.errors import (
AuthError,
@@ -73,36 +74,45 @@ class BaseProfileHandler(BaseHandler):
)
if len(self.hs.config.replicate_user_profiles_to) > 0:
- reactor.callWhenRunning(self._assign_profile_replication_batches)
- reactor.callWhenRunning(self._replicate_profiles)
+ reactor.callWhenRunning(self._do_assign_profile_replication_batches)
+ reactor.callWhenRunning(self._start_replicate_profiles)
# Add a looping call to replicate_profiles: this handles retries
# if the replication is unsuccessful when the user updated their
# profile.
self.clock.looping_call(
- self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
)
- @defer.inlineCallbacks
- def _assign_profile_replication_batches(self):
+ def _do_assign_profile_replication_batches(self):
+ return run_as_background_process(
+ "_assign_profile_replication_batches",
+ self._assign_profile_replication_batches,
+ )
+
+ def _start_replicate_profiles(self):
+ return run_as_background_process(
+ "_replicate_profiles", self._replicate_profiles
+ )
+
+ async def _assign_profile_replication_batches(self):
"""If no profile replication has been done yet, allocate replication batch
numbers to each profile to start the replication process.
"""
logger.info("Assigning profile batch numbers...")
total = 0
while True:
- assigned = yield self.store.assign_profile_batch()
+ assigned = await self.store.assign_profile_batch()
total += assigned
if assigned == 0:
break
logger.info("Assigned %d profile batch numbers", total)
- @defer.inlineCallbacks
- def _replicate_profiles(self):
+ async def _replicate_profiles(self):
"""If any profile data has been updated and not pushed to the replication targets,
replicate it.
"""
- host_batches = yield self.store.get_replication_hosts()
- latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ host_batches = await self.store.get_replication_hosts()
+ latest_batch = await self.store.get_latest_profile_replication_batch_number()
if latest_batch is None:
latest_batch = -1
for repl_host in self.hs.config.replicate_user_profiles_to:
@@ -110,16 +120,15 @@ class BaseProfileHandler(BaseHandler):
host_batches[repl_host] = -1
try:
for i in range(host_batches[repl_host] + 1, latest_batch + 1):
- yield self._replicate_host_profile_batch(repl_host, i)
+ await self._replicate_host_profile_batch(repl_host, i)
except Exception:
logger.exception(
"Exception while replicating to %s: aborting for now", repl_host
)
- @defer.inlineCallbacks
- def _replicate_host_profile_batch(self, host, batchnum):
+ async def _replicate_host_profile_batch(self, host, batchnum):
logger.info("Replicating profile batch %d to %s", batchnum, host)
- batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch_rows = await self.store.get_profile_batch(batchnum)
batch = {
UserID(r["user_id"], self.hs.hostname).to_string(): (
{"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
@@ -133,13 +142,11 @@ class BaseProfileHandler(BaseHandler):
body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
try:
- yield defer.ensureDeferred(
- self.http_client.post_json_get_json(url, signed_body)
- )
- yield defer.ensureDeferred(
- self.store.update_replication_batch_for_host(host, batchnum)
+ await self.http_client.post_json_get_json(url, signed_body)
+ await self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info(
+ "Successfully replicated profile batch %d to %s", batchnum, host
)
- logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
except Exception:
# This will get retried when the looping call next comes around
logger.exception(
@@ -292,8 +299,7 @@ class BaseProfileHandler(BaseHandler):
# start a profile replication push
run_in_background(self._replicate_profiles)
- @defer.inlineCallbacks
- def set_active(
+ async def set_active(
self, users: List[UserID], active: bool, hide: bool,
):
"""
@@ -316,19 +322,16 @@ class BaseProfileHandler(BaseHandler):
hide: Whether to hide the user (withhold from replication). If
False and active is False, user will have their profile
erased
-
- Returns:
- Deferred
"""
if len(self.replicate_user_profiles_to) > 0:
cur_batchnum = (
- yield self.store.get_latest_profile_replication_batch_number()
+ await self.store.get_latest_profile_replication_batch_number()
)
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
- yield self.store.set_profiles_active(users, active, hide, new_batchnum)
+ await self.store.set_profiles_active(users, active, hide, new_batchnum)
# start a profile replication push
run_in_background(self._replicate_profiles)
@@ -362,8 +365,14 @@ class BaseProfileHandler(BaseHandler):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
- """target_user is the user whose avatar_url is to be changed;
- auth_user is the user attempting to make this change."""
+ """Set a new avatar URL for a user.
+
+ Args:
+ target_user (UserID): the user whose avatar URL is to be changed.
+ requester (Requester): The user attempting to make this change.
+ new_avatar_url (str): The avatar URL to give this user.
+ by_admin (bool): Whether this change was made by an administrator.
+ """
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -484,6 +493,12 @@ class BaseProfileHandler(BaseHandler):
await self.ratelimit(requester)
+ # Do not actually update the room state for shadow-banned users.
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ return
+
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
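
The profile handler changes above wrap the now-async replication methods so that `reactor.callWhenRunning` and `clock.looping_call`, which expect plain callables, still work. A condensed sketch of that wrapper pattern (class and method names are illustrative, not part of the diff):

```python
from synapse.metrics.background_process_metrics import run_as_background_process

class ExampleReplicator:
    def __init__(self, clock):
        # looping_call wants a plain callable, so hand it a thin wrapper...
        clock.looping_call(self._start_replicate, 2 * 60 * 1000)

    def _start_replicate(self):
        # ...which launches the coroutine as a named background process,
        # giving it logcontext handling and metrics.
        return run_as_background_process("replicate", self._replicate)

    async def _replicate(self):
        ...  # the actual async work
```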
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8ee9b2063d..50f3756446 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -384,7 +384,7 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
- if not self.allow_per_room_profiles:
+ if not self.allow_per_room_profiles or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a86ac0150e..1d828bd7be 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -14,10 +14,11 @@
# limitations under the License.
import logging
+import random
from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
@@ -227,9 +228,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
- async def started_typing(self, target_user, auth_user, room_id, timeout):
+ async def started_typing(self, target_user, requester, room_id, timeout):
target_user_id = target_user.to_string()
- auth_user_id = auth_user.to_string()
+ auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -237,6 +238,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
await self.auth.check_user_in_room(room_id, target_user_id)
logger.debug("%s has started typing in %s", target_user_id, room_id)
@@ -256,9 +262,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update(member=member, typing=True)
- async def stopped_typing(self, target_user, auth_user, room_id):
+ async def stopped_typing(self, target_user, requester, room_id):
target_user_id = target_user.to_string()
- auth_user_id = auth_user.to_string()
+ auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -266,6 +272,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
await self.auth.check_user_in_room(room_id, target_user_id)
logger.debug("%s has stopped typing in %s", target_user_id, room_id)
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index d43eaf3a29..047f2c50f7 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -21,9 +21,9 @@ class SlavedIdTracker(object):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
- self.advance(_load_current_id(db_conn, table, column))
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, new_id):
+ def advance(self, instance_name, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
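
With `advance` now taking `(instance_name, new_id)` on both `SlavedIdTracker` and `MultiWriterIdGenerator`, the replication stores below can call it uniformly; the slaved tracker simply ignores the instance name. A tiny sketch of the unified call site (names illustrative):

```python
# The same replication code now works against either ID generator, since
# both expose advance(instance_name, new_id).
def process_stream_position(id_gen, instance_name, token):
    # SlavedIdTracker ignores instance_name; MultiWriterIdGenerator uses it
    # to track each writer's position separately.
    id_gen.advance(instance_name, token)
```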
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 154f0e687c..bb66ba9b80 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(token)
+ self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(token)
+ self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index ee7f69a918..533d927701 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
- self._device_inbox_id_gen.advance(token)
+ self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 722f3745e9..596c72eb92 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(token)
+ self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(token)
+ self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 3291558c7a..567b4a5cc1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == GroupServerStream.NAME:
- self._group_updates_id_gen.advance(token)
+ self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index a912c04360..025f6f6be8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PresenceStream.NAME:
- self._presence_id_gen.advance(token)
+ self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 90d90833f9..de904c943c 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
- self._push_rules_stream_id_gen.advance(token)
+ self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 63300e5da6..9da218bfe8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(token)
+ self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 17ba1f22ac..5c2986e050 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
- self._receipts_id_gen.advance(token)
+ self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 427c81772b..80ae803ad9 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PublicRoomsStream.NAME:
- self._public_room_id_gen.advance(token)
+ self._public_room_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 7ed1ccb5a0..3929d5519d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -171,7 +171,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
room_id=room_id,
event_type=event_type,
state_key=state_key,
- is_guest=requester.is_guest,
)
if not data:
@@ -870,17 +869,21 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
- if content["typing"]:
- await self.typing_handler.started_typing(
- target_user=target_user,
- auth_user=requester.user,
- room_id=room_id,
- timeout=timeout,
- )
- else:
- await self.typing_handler.stopped_typing(
- target_user=target_user, auth_user=requester.user, room_id=room_id
- )
+ try:
+ if content["typing"]:
+ await self.typing_handler.started_typing(
+ target_user=target_user,
+ requester=requester,
+ room_id=room_id,
+ timeout=timeout,
+ )
+ else:
+ await self.typing_handler.stopped_typing(
+ target_user=target_user, requester=requester, room_id=room_id
+ )
+ except ShadowBanError:
+ # Pretend this worked without error.
+ pass
return 200, {}
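
The typing servlet above shows the general shadow-ban shape: the handler raises `ShadowBanError`, and the REST layer swallows it so the request appears to succeed. A condensed sketch of the two halves together (the class names and `do_action` are hypothetical):

```python
import random

from synapse.api.errors import ShadowBanError

class ExampleHandler:
    async def do_action(self, requester, room_id):
        if requester.shadow_banned:
            # Waste a little of the requester's time, then bail out without
            # changing any state.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()
        ...  # normal processing

class ExampleServlet:
    async def on_PUT(self, request, room_id):
        requester = await self.auth.get_user_by_req(request)
        try:
            await self.handler.do_action(requester, room_id)
        except ShadowBanError:
            # Pretend this worked without error.
            pass
        return 200, {}
```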
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index bc327e344e..181c3ec249 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -29,9 +29,11 @@ from typing import (
Tuple,
TypeVar,
Union,
+ overload,
)
from prometheus_client import Histogram
+from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
- def simple_select_one(
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one",
+ ) -> Dict[str, Any]:
+ ...
+
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
+ ...
+
+ async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
- ) -> defer.Deferred:
+ ) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def simple_select_one_onecol(
+ async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
- ) -> defer.Deferred:
+ ) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
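
The `@overload` pair added above lets mypy narrow the return type on the literal value of `allow_none`: callers that omit it (or pass `False`) get a non-optional dict, while `allow_none=True` yields an `Optional`. The same typing pattern in miniature (a generic example, not Synapse code):

```python
from typing import Any, Dict, Optional, overload

from typing_extensions import Literal

_TABLE = {"alice": {"displayname": "Alice"}}  # stand-in data source

@overload
async def fetch_row(key: str, allow_none: Literal[False] = False) -> Dict[str, Any]:
    ...

@overload
async def fetch_row(key: str, allow_none: Literal[True] = True) -> Optional[Dict[str, Any]]:
    ...

async def fetch_row(key: str, allow_none: bool = False) -> Optional[Dict[str, Any]]:
    row = _TABLE.get(key)
    if row is None and not allow_none:
        raise KeyError(key)
    # mypy sees Dict[str, Any] for allow_none=False callers and
    # Optional[Dict[str, Any]] when allow_none=True.
    return row
```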
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 03b45dbc4d..a811a39eb5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id: str, device_id: str):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
- defer.Deferred for a dict containing the device information
+ A dict containing the device information
Raises:
StoreError: if the device is not found
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id: str):
+ async def get_device_list_last_stream_id_for_remote(
+ self, user_id: str
+ ) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..301d5d845a 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
- def get_room_alias_creator(self, room_alias):
- return self.db_pool.simple_select_one_onecol(
+ async def get_room_alias_creator(self, room_alias: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..46c3e33cc6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
- def count_e2e_room_keys(self, user_id, version):
+ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e1241a724b..e6247d682d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(token)
+ self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
- self._backfill_id_gen.advance(-token)
+ self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_received_ts(self, event_id):
+ async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
- event_id (str)
+ event_id: The event ID to query.
Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
+ Timestamp in milliseconds, or None for events that were persisted
+ before received_ts was implemented.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index a488e0924b..c39864f59f 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
- def get_group(self, group_id):
- return self.db_pool.simple_select_one(
+ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
)
return bool(result)
- def is_user_admin_in_group(self, group_id, user_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_admin_in_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
+ async def is_user_invited_to_local_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
"""Has the group server invited a user?
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..4ae255ebd8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Optional
+
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
- def get_local_media(self, media_id):
+ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
+
Returns:
None if the media_id doesn't exist.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
- def get_cached_remote_media(self, origin, media_id):
- return self.db_pool.simple_select_one(
+ async def get_cached_remote_media(
+ self, origin, media_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..fe30552c08 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
+ async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id: user to add/update
+
+ Return:
+ Timestamp since last seen, None if never seen
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 086cfbeed4..f607b823f8 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -13,8 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -26,7 +25,7 @@ BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
- async def get_profileinfo(self, user_localpart):
+ async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@@ -46,8 +45,8 @@ class ProfileWorkerStore(SQLBaseStore):
)
@cached(max_entries=5000)
- def get_profile_displayname(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_displayname(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
@@ -55,33 +54,33 @@ class ProfileWorkerStore(SQLBaseStore):
)
@cached(max_entries=5000)
- def get_profile_avatar_url(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_avatar_url(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
- def get_latest_profile_replication_batch_number(self):
+ async def get_latest_profile_replication_batch_number(self):
def f(txn):
txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
rows = self.db_pool.cursor_to_dict(txn)
return rows[0]["maxbatch"]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_latest_profile_replication_batch_number", f
)
- def get_profile_batch(self, batchnum):
- return self.db_pool.simple_select_list(
+ async def get_profile_batch(self, batchnum):
+ return await self.db_pool.simple_select_list(
table="profiles",
keyvalues={"batch": batchnum},
retcols=("user_id", "displayname", "avatar_url", "active"),
desc="get_profile_batch",
)
- def assign_profile_batch(self):
+ async def assign_profile_batch(self):
def f(txn):
sql = (
"UPDATE profiles SET batch = "
@@ -93,9 +92,9 @@ class ProfileWorkerStore(SQLBaseStore):
txn.execute(sql, (BATCH_SIZE,))
return txn.rowcount
- return self.db_pool.runInteraction("assign_profile_batch", f)
+ return await self.db_pool.runInteraction("assign_profile_batch", f)
- def get_replication_hosts(self):
+ async def get_replication_hosts(self):
def f(txn):
txn.execute(
"SELECT host, last_synced_batch FROM profile_replication_status"
@@ -103,18 +102,22 @@ class ProfileWorkerStore(SQLBaseStore):
rows = self.db_pool.cursor_to_dict(txn)
return {r["host"]: r["last_synced_batch"] for r in rows}
- return self.db_pool.runInteraction("get_replication_hosts", f)
+ return await self.db_pool.runInteraction("get_replication_hosts", f)
- def update_replication_batch_for_host(self, host, last_synced_batch):
- return self.db_pool.simple_upsert(
+ async def update_replication_batch_for_host(
+ self, host: str, last_synced_batch: int
+ ):
+ return await self.db_pool.simple_upsert(
table="profile_replication_status",
keyvalues={"host": host},
values={"last_synced_batch": last_synced_batch},
desc="update_replication_batch_for_host",
)
- def get_from_remote_profile_cache(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_from_remote_profile_cache(
+ self, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -151,9 +154,9 @@ class ProfileWorkerStore(SQLBaseStore):
lock=False, # we can do this because user_id has a unique index
)
- def set_profiles_active(
+ async def set_profiles_active(
self, users: List[UserID], active: bool, hide: bool, batchnum: int,
- ):
+ ) -> None:
"""Given a set of users, set active and hidden flags on them.
Args:
@@ -163,9 +166,6 @@ class ProfileWorkerStore(SQLBaseStore):
False and active is False, users will have their profiles
erased
batchnum: The batch number, used for profile replication
-
- Returns:
- Deferred
"""
# Convert list of localparts to list of tuples containing localparts
user_localparts = [(user.localpart,) for user in users]
@@ -180,7 +180,7 @@ class ProfileWorkerStore(SQLBaseStore):
value_names += ("avatar_url", "displayname")
values = [v + (None, None) for v in values]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"set_profiles_active",
self.db_pool.simple_upsert_many_txn,
table="profiles",
@@ -225,7 +225,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- updatevalues={
+ values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 6821476ee0..cea5ac9a68 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
- def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.db_pool.simple_select_one_onecol(
+ async def get_last_receipt_event_id_for_user(
+ self, user_id: str, room_id: str, receipt_type: str
+ ) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 336b578e23..48c979aeea 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Awaitable, Dict, List, Optional
+from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
- def get_user_by_id(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -1338,12 +1338,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
- def get_user_pending_deactivation(self):
+ async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def get_rejection_reason(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 99a8a9fab0..dc97f70c66 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
- def get_room(self, room_id):
+ async def get_room(self, room_id: str) -> dict:
"""Retrieve a room.
Args:
- room_id (str): The ID of the room to retrieve.
+ room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 991233a9bc..458f169617 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+ return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..9fe97af56a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
- def get_stats_positions(self):
+ async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- def get_earliest_token_for_stats(self, stats_type, id):
+ async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
- Deferred[int]
+ The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..20cbcd851c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
import logging
import re
+from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- def get_user_in_directory(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- def get_user_directory_stream_pos(self):
- return self.db_pool.simple_select_one_onecol(
+ async def get_user_directory_stream_pos(self) -> int:
+ return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d7f0c19c4c..3cd4f71d40 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -106,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
@defer.inlineCallbacks
@@ -201,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/pic.gif",
)
@@ -215,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
@@ -231,7 +244,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 4b627dac00..2b7eeef129 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -262,7 +262,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# self.handler.notify_new_event()
# We need to let the delta processor advance…
- self.pump(10 * 60)
+ self.reactor.advance(10 * 60)
# Get the slices! There should be two -- day 1, and day 2.
r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index e01de158e5..81c1839637 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,7 +21,7 @@ from mock import ANY, Mock, call
from twisted.internet import defer
from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.return_value = (
+ self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
- defer.succeed(1)
+ lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
@@ -167,7 +167,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -194,7 +197,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -269,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.stopped_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
)
)
@@ -309,7 +317,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
@@ -348,7 +359,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
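
The mock change above replaces `return_value` with `side_effect`: `get_user_directory_stream_pos` is now a coroutine function, so every call must hand back a fresh awaitable rather than one shared Deferred. In isolation (using the `make_awaitable` helper imported by these tests):

```python
from mock import Mock

from tests.test_utils import make_awaitable

datastore = Mock()

# A side_effect callable builds a new awaitable per call; return_value would
# reuse a single one, which can only be consumed safely once.
datastore.get_user_directory_stream_pos.side_effect = lambda: make_awaitable(1)
```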
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ac598249e4..7f1dfb2128 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -508,6 +508,6 @@ class FederationClientTests(HomeserverTestCase):
self.assertFalse(conn.disconnecting)
# wait for a while
- self.pump(120)
+ self.reactor.advance(120)
self.assertTrue(conn.disconnecting)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9c778a0e45..ccbb82f6a3 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -47,7 +47,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
- self.assertTrue(self.store.get_user_by_id(user_id))
+ self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 83032cc9ea..227b0d32d0 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase):
last_stream_ordering = pushers[0]["last_stream_ordering"]
# Advance time a bit, so the pusher will register something has happened
- self.pump(100)
+ self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 83f9aa291c..8b4982ecb1 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import make_awaitable
@@ -175,7 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(
typing_handler.started_typing(
target_user=UserID.from_string(user),
- auth_user=UserID.from_string(user),
+ requester=create_requester(user),
room_id=room,
timeout=20000,
)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index c7e287c61e..47c0d5634c 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -179,7 +179,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
+ self.user_id, room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
new file mode 100644
index 0000000000..dfe4bf7762
--- /dev/null
+++ b/tests/rest/client/test_shadow_banned.py
@@ -0,0 +1,312 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock, patch
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+
+
+class _ShadowBannedBase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ # Create two users, one of which is shadow-banned.
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class RoomTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ def test_typing(self):
+ """Typing notifications should not be propagated into the room."""
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # There should be no typing events.
+ event_source = self.hs.get_event_sources().sources["typing"]
+ self.assertEquals(event_source.get_current_key(), 0)
+
+ # The other user can join and send typing events.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.other_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # These appear in the room.
+ self.assertEquals(event_source.get_current_key(), 1)
+ events = self.get_success(
+ event_source.get_new_events(from_key=0, room_ids=[room_id])
+ )
+ self.assertEquals(
+ events[0],
+ [
+ {
+ "type": "m.typing",
+ "room_id": room_id,
+ "content": {"user_ids": [self.other_user_id]},
+ }
+ ],
+ )
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ProfileTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_displayname(self):
+ """Profile changes should succeed, but don't end up in a room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
+ {"displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEqual(channel.json_body, {})
+
+ # The user's display name should be updated.
+ request, channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (self.banned_user_id,)
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["displayname"], new_display_name)
+
+ # But the display name in the room should not be.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
+
+ def test_room_displayname(self):
+ """Changes to state events for a room should be processed, but not end up in the room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+ % (room_id, self.banned_user_id),
+ {"membership": "join", "displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("event_id", channel.json_body)
+
+ # The display name in the room should not be changed.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
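
The new test file drives shadow banning purely through public REST endpoints; the only direct database access is the `simple_update` in `_ShadowBannedBase.prepare()` that flips the `shadow_banned` column, and `random.randint` is patched out because Synapse otherwise adds a random delay to "annoy the requester". If other test modules ever need the same setup, that flag-flipping could be factored into a helper; a hypothetical sketch (not in the patch) reusing only the calls shown above:

    def shadow_ban_user(testcase, user_id: str) -> None:
        """Mark `user_id` as shadow-banned directly in the users table."""
        testcase.get_success(
            testcase.hs.get_datastore().db_pool.simple_update(
                table="users",
                keyvalues={"name": user_id},
                updatevalues={"shadow_banned": True},
                desc="shadow_ban",
            )
        )

    # In prepare():  shadow_ban_user(self, self.banned_user_id)
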
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 68c4a6a8f7..0a567b032f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -21,13 +21,13 @@
import json
from urllib import parse as urlparse
-from mock import Mock, patch
+from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
+from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
@@ -2060,158 +2060,3 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
-
-
-# To avoid the tests timing out don't add a delay to "annoy the requester".
-@patch("random.randint", new=lambda a, b: 0)
-class ShadowBannedTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- directory.register_servlets,
- login.register_servlets,
- room.register_servlets,
- room_upgrade_rest_servlet.register_servlets,
- ]
-
- def prepare(self, reactor, clock, homeserver):
- self.banned_user_id = self.register_user("banned", "test")
- self.banned_access_token = self.login("banned", "test")
-
- self.store = self.hs.get_datastore()
-
- self.get_success(
- self.store.db_pool.simple_update(
- table="users",
- keyvalues={"name": self.banned_user_id},
- updatevalues={"shadow_banned": True},
- desc="shadow_ban",
- )
- )
-
- self.other_user_id = self.register_user("otheruser", "pass")
- self.other_access_token = self.login("otheruser", "pass")
-
- def test_invite(self):
- """Invites from shadow-banned users don't actually get sent."""
-
- # The create works fine.
- room_id = self.helper.create_room_as(
- self.banned_user_id, tok=self.banned_access_token
- )
-
- # Inviting the user completes successfully.
- self.helper.invite(
- room=room_id,
- src=self.banned_user_id,
- tok=self.banned_access_token,
- targ=self.other_user_id,
- )
-
- # But the user wasn't actually invited.
- invited_rooms = self.get_success(
- self.store.get_invited_rooms_for_local_user(self.other_user_id)
- )
- self.assertEqual(invited_rooms, [])
-
- def test_invite_3pid(self):
- """Ensure that a 3PID invite does not attempt to contact the identity server."""
- identity_handler = self.hs.get_handlers().identity_handler
- identity_handler.lookup_3pid = Mock(
- side_effect=AssertionError("This should not get called")
- )
-
- # The create works fine.
- room_id = self.helper.create_room_as(
- self.banned_user_id, tok=self.banned_access_token
- )
-
- # Inviting the user completes successfully.
- request, channel = self.make_request(
- "POST",
- "/rooms/%s/invite" % (room_id,),
- {"id_server": "test", "medium": "email", "address": "test@test.test"},
- access_token=self.banned_access_token,
- )
- self.render(request)
- self.assertEquals(200, channel.code, channel.result)
-
- # This should have raised an error earlier, but double check this wasn't called.
- identity_handler.lookup_3pid.assert_not_called()
-
- def test_create_room(self):
- """Invitations during a room creation should be discarded, but the room still gets created."""
- # The room creation is successful.
- request, channel = self.make_request(
- "POST",
- "/_matrix/client/r0/createRoom",
- {"visibility": "public", "invite": [self.other_user_id]},
- access_token=self.banned_access_token,
- )
- self.render(request)
- self.assertEquals(200, channel.code, channel.result)
- room_id = channel.json_body["room_id"]
-
- # But the user wasn't actually invited.
- invited_rooms = self.get_success(
- self.store.get_invited_rooms_for_local_user(self.other_user_id)
- )
- self.assertEqual(invited_rooms, [])
-
- # Since a real room was created, the other user should be able to join it.
- self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
-
- # Both users should be in the room.
- users = self.get_success(self.store.get_users_in_room(room_id))
- self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
-
- def test_message(self):
- """Messages from shadow-banned users don't actually get sent."""
-
- room_id = self.helper.create_room_as(
- self.other_user_id, tok=self.other_access_token
- )
-
- # The user should be in the room.
- self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
-
- # Sending a message should complete successfully.
- result = self.helper.send_event(
- room_id=room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "with right label"},
- tok=self.banned_access_token,
- )
- self.assertIn("event_id", result)
- event_id = result["event_id"]
-
- latest_events = self.get_success(
- self.store.get_latest_event_ids_in_room(room_id)
- )
- self.assertNotIn(event_id, latest_events)
-
- def test_upgrade(self):
- """A room upgrade should fail, but look like it succeeded."""
-
- # The create works fine.
- room_id = self.helper.create_room_as(
- self.banned_user_id, tok=self.banned_access_token
- )
-
- request, channel = self.make_request(
- "POST",
- "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
- {"new_version": "6"},
- access_token=self.banned_access_token,
- )
- self.render(request)
- self.assertEquals(200, channel.code, channel.result)
- # A new room_id should be returned.
- self.assertIn("replacement_room", channel.json_body)
-
- new_room_id = channel.json_body["replacement_room"]
-
- # It doesn't really matter what API we use here, we just want to assert
- # that the room doesn't exist.
- summary = self.get_success(self.store.get_room_summary(new_room_id))
- # The summary should be empty since the room doesn't exist.
- self.assertEqual(summary, {})
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 13bcac743a..bf22540d99 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db_pool.simple_select_one_onecol(
- table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ value = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one_onecol(
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ )
)
self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"],
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ retcols=["colA", "colB", "colC"],
+ )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "Not here"},
- retcols=["colA"],
- allow_none=True,
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "Not here"},
+ retcols=["colA"],
+ allow_none=True,
+ )
)
self.assertFalse(ret)
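
This wrapping pattern repeats throughout the storage tests below (devices, profile, registration, room): the pool methods are now coroutine functions, and inside an `@defer.inlineCallbacks` generator yielding something that is not a Deferred just hands the object straight back, so a bare coroutine would come back un-awaited. Hence every call is wrapped in `defer.ensureDeferred(...)` first. A self-contained sketch with a stand-in coroutine in place of the real query method:

    from twisted.internet import defer

    async def simple_select_one_onecol(table, keyvalues, retcol):
        return "Value"            # stand-in for the real database query

    @defer.inlineCallbacks
    def select_one_col():
        value = yield defer.ensureDeferred(
            simple_select_one_onecol(
                table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
            )
        )
        assert value == "Value"

    d = select_one_col()          # everything is synchronous, so it fires at once
    assert d.called
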
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43639ca286..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
- self.pump(10 * 60)
+ self.pump(20)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 16a32cb819..7a38022e71 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -40,7 +40,12 @@ class ProfileStoreTestCase(unittest.TestCase):
)
self.assertEquals(
- "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ ),
)
@defer.inlineCallbacks
@@ -55,5 +60,9 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ ),
)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 58f827d8d3..70c55cd650 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
- (yield self.store.get_user_by_id(self.user_id)),
+ (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield self.store.get_room(self.room.to_string())),
+ (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+ self.assertIsNone(
+ (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+ )
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (yield self.store.get_room_with_stats(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
- self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats("!uknown:test")
+ )
+ ),
+ )
class RoomEventsStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index d98fe8754d..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump()
self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Initialises to 1 -- itself
self.assertEqual(self.store._known_servers_count, 1)
- self.pump(20)
+ self.pump()
# No rooms have been joined, so technically the SQL returns 0, but it
# will still say it knows about itself.
@@ -111,7 +111,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump(1)
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index bc42ffce88..5f46ed0cef 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -91,7 +91,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
#
# one more go, with success
#
- self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
diff --git a/tox.ini b/tox.ini
index f8ecd1aa69..6329ba286a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -172,58 +172,8 @@ deps =
{[base]deps}
mypy==0.782
mypy-zope
-env =
- MYPYPATH = stubs/
extras = all
-commands = mypy \
- synapse/api \
- synapse/appservice \
- synapse/config \
- synapse/event_auth.py \
- synapse/events/builder.py \
- synapse/events/spamcheck.py \
- synapse/federation \
- synapse/handlers/auth.py \
- synapse/handlers/cas_handler.py \
- synapse/handlers/directory.py \
- synapse/handlers/federation.py \
- synapse/handlers/identity.py \
- synapse/handlers/message.py \
- synapse/handlers/oidc_handler.py \
- synapse/handlers/presence.py \
- synapse/handlers/room.py \
- synapse/handlers/room_member.py \
- synapse/handlers/room_member_worker.py \
- synapse/handlers/saml_handler.py \
- synapse/handlers/sync.py \
- synapse/handlers/ui_auth \
- synapse/http/server.py \
- synapse/http/site.py \
- synapse/logging/ \
- synapse/metrics \
- synapse/module_api \
- synapse/notifier.py \
- synapse/push/pusherpool.py \
- synapse/push/push_rule_evaluator.py \
- synapse/replication \
- synapse/rest \
- synapse/server.py \
- synapse/server_notices \
- synapse/spam_checker_api \
- synapse/state \
- synapse/storage/databases/main/ui_auth.py \
- synapse/storage/database.py \
- synapse/storage/engines \
- synapse/storage/state.py \
- synapse/storage/util \
- synapse/streams \
- synapse/types.py \
- synapse/util/caches/stream_change_cache.py \
- synapse/util/metrics.py \
- tests/replication \
- tests/test_utils \
- tests/rest/client/v2_alpha/test_auth.py \
- tests/util/test_stream_change_cache.py
+commands = mypy
# To find all folders that pass mypy you run:
#