diff --git a/changelog.d/7839.docker b/changelog.d/7839.docker
new file mode 100644
index 0000000000..cdf3c9631c
--- /dev/null
+++ b/changelog.d/7839.docker
@@ -0,0 +1 @@
+Base the Docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196.
diff --git a/changelog.d/7842.feature b/changelog.d/7842.feature
new file mode 100644
index 0000000000..727deb01c9
--- /dev/null
+++ b/changelog.d/7842.feature
@@ -0,0 +1 @@
+Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/7849.misc b/changelog.d/7849.misc
new file mode 100644
index 0000000000..e3296418c1
--- /dev/null
+++ b/changelog.d/7849.misc
@@ -0,0 +1 @@
+Consistently use `db_to_json` to convert from database values to JSON objects.
diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature
new file mode 100644
index 0000000000..2b6a9f0e71
--- /dev/null
+++ b/changelog.d/7855.feature
@@ -0,0 +1 @@
+Add experimental support for running multiple pusher workers.
diff --git a/changelog.d/7858.misc b/changelog.d/7858.misc
new file mode 100644
index 0000000000..8f0fc2de74
--- /dev/null
+++ b/changelog.d/7858.misc
@@ -0,0 +1 @@
+The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100.
diff --git a/changelog.d/7859.bugfix b/changelog.d/7859.bugfix
new file mode 100644
index 0000000000..19cff4b061
--- /dev/null
+++ b/changelog.d/7859.bugfix
@@ -0,0 +1 @@
+Fix a bug which allowed empty rooms to be rejoined over federation.
diff --git a/changelog.d/7860.misc b/changelog.d/7860.misc
new file mode 100644
index 0000000000..fdd48b955c
--- /dev/null
+++ b/changelog.d/7860.misc
@@ -0,0 +1 @@
+Convert the `_base`, `profile`, and `_receipts` handlers to async/await.
diff --git a/changelog.d/7861.misc b/changelog.d/7861.misc
new file mode 100644
index 0000000000..ada616c62f
--- /dev/null
+++ b/changelog.d/7861.misc
@@ -0,0 +1 @@
+Optimise queueing of inbound replication commands.
diff --git a/changelog.d/7866.bugfix b/changelog.d/7866.bugfix
new file mode 100644
index 0000000000..6b5c3c4eca
--- /dev/null
+++ b/changelog.d/7866.bugfix
@@ -0,0 +1 @@
+Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.
diff --git a/changelog.d/7868.misc b/changelog.d/7868.misc
new file mode 100644
index 0000000000..eadef5e4c2
--- /dev/null
+++ b/changelog.d/7868.misc
@@ -0,0 +1 @@
+Convert synapse.app and federation client to async/await.
diff --git a/changelog.d/7869.feature b/changelog.d/7869.feature
new file mode 100644
index 0000000000..1982049a52
--- /dev/null
+++ b/changelog.d/7869.feature
@@ -0,0 +1 @@
+Add experimental support for moving typing off master.
diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc
new file mode 100644
index 0000000000..4d398a9f3a
--- /dev/null
+++ b/changelog.d/7871.misc
@@ -0,0 +1 @@
+Convert device handler to async/await.
diff --git a/changelog.d/7872.bugfix b/changelog.d/7872.bugfix
new file mode 100644
index 0000000000..b21f8e1f14
--- /dev/null
+++ b/changelog.d/7872.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where the tracing of async functions with opentracing was broken.
diff --git a/changelog.d/7880.bugfix b/changelog.d/7880.bugfix
new file mode 100644
index 0000000000..356add0996
--- /dev/null
+++ b/changelog.d/7880.bugfix
@@ -0,0 +1 @@
+Fix "TypeError in `synapse.notifier`" exceptions.
diff --git a/changelog.d/7881.misc b/changelog.d/7881.misc
new file mode 100644
index 0000000000..6799117099
--- /dev/null
+++ b/changelog.d/7881.misc
@@ -0,0 +1 @@
+Change "unknown room version" logging from 'error' to 'warning'.
diff --git a/changelog.d/7882.misc b/changelog.d/7882.misc
new file mode 100644
index 0000000000..9002749335
--- /dev/null
+++ b/changelog.d/7882.misc
@@ -0,0 +1 @@
+Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`.
diff --git a/changelog.d/7885.doc b/changelog.d/7885.doc
new file mode 100644
index 0000000000..cbe9de4082
--- /dev/null
+++ b/changelog.d/7885.doc
@@ -0,0 +1 @@
+Provide instructions on using `register_new_matrix_user` via Docker.
diff --git a/changelog.d/7888.misc b/changelog.d/7888.misc
new file mode 100644
index 0000000000..5328d2dcca
--- /dev/null
+++ b/changelog.d/7888.misc
@@ -0,0 +1 @@
+Remove Ubuntu Eoan from the list of `.deb` packages that we build, as it is now end-of-life. Contributed by @gary-kim.
diff --git a/changelog.d/7889.doc b/changelog.d/7889.doc
new file mode 100644
index 0000000000..d91f62fd39
--- /dev/null
+++ b/changelog.d/7889.doc
@@ -0,0 +1 @@
+Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation.
\ No newline at end of file
diff --git a/changelog.d/7890.misc b/changelog.d/7890.misc
new file mode 100644
index 0000000000..8c127084bc
--- /dev/null
+++ b/changelog.d/7890.misc
@@ -0,0 +1 @@
+Fix typo in generated config file. Contributed by @ThiefMaster.
diff --git a/changelog.d/7892.misc b/changelog.d/7892.misc
new file mode 100644
index 0000000000..ef4cfa04fd
--- /dev/null
+++ b/changelog.d/7892.misc
@@ -0,0 +1 @@
+Import ABC from `collections.abc` for Python 3.10 compatibility.
diff --git a/changelog.d/7895.bugfix b/changelog.d/7895.bugfix
new file mode 100644
index 0000000000..1ae7f8ca7c
--- /dev/null
+++ b/changelog.d/7895.bugfix
@@ -0,0 +1 @@
+Fix deprecation warning due to invalid escape sequences.
\ No newline at end of file
diff --git a/changelog.d/7897.misc b/changelog.d/7897.misc
new file mode 100644
index 0000000000..77772533fd
--- /dev/null
+++ b/changelog.d/7897.misc
@@ -0,0 +1,2 @@
+Remove the unused functions `time_function`, `trace_function`, `get_previous_frames`,
+and `get_previous_frame` from the `synapse.logging.utils` module.
\ No newline at end of file
diff --git a/changelog.d/7912.misc b/changelog.d/7912.misc
new file mode 100644
index 0000000000..d619590070
--- /dev/null
+++ b/changelog.d/7912.misc
@@ -0,0 +1 @@
+Convert `RoomListHandler` to async/await.
diff --git a/changelog.d/7914.misc b/changelog.d/7914.misc
new file mode 100644
index 0000000000..710553249c
--- /dev/null
+++ b/changelog.d/7914.misc
@@ -0,0 +1 @@
+Lint the `contrib/` directory in CI and the linting scripts, and add `synctl` to the linting script for consistency with CI.
diff --git a/changelog.d/7919.misc b/changelog.d/7919.misc
new file mode 100644
index 0000000000..addaa35183
--- /dev/null
+++ b/changelog.d/7919.misc
@@ -0,0 +1 @@
+Use Element CSS and logo in notification emails when app name is Element.
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 48da410d94..77422f5e5d 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -17,9 +17,6 @@
""" Starts a synapse client console. """
from __future__ import print_function
-from twisted.internet import reactor, defer, threads
-from http import TwistedHttpClient
-
import argparse
import cmd
import getpass
@@ -28,12 +25,14 @@ import shlex
import sys
import time
import urllib
-import urlparse
+from http import TwistedHttpClient
-import nacl.signing
import nacl.encoding
+import nacl.signing
+import urlparse
+from signedjson.sign import SignatureVerifyException, verify_signed_json
-from signedjson.sign import verify_signed_json, SignatureVerifyException
+from twisted.internet import defer, reactor, threads
CONFIG_JSON = "cmdclient_config.json"
@@ -493,7 +492,7 @@ class SynapseCmd(cmd.Cmd):
"list messages <roomid> from=END&to=START&limit=3"
"""
args = self._parse(line, ["type", "roomid", "qp"])
- if not "type" in args or not "roomid" in args:
+ if "type" not in args or "roomid" not in args:
print("Must specify type and room ID.")
return
if args["type"] not in ["members", "messages"]:
@@ -508,7 +507,7 @@ class SynapseCmd(cmd.Cmd):
try:
key_value = key_value_str.split("=")
qp[key_value[0]] = key_value[1]
- except:
+ except Exception:
print("Bad query param: %s" % key_value)
return
@@ -585,7 +584,7 @@ class SynapseCmd(cmd.Cmd):
parsed_url = urlparse.urlparse(args["path"])
qp.update(urlparse.parse_qs(parsed_url.query))
args["path"] = parsed_url.path
- except:
+ except Exception:
pass
reactor.callFromThread(
@@ -772,10 +771,10 @@ def main(server_url, identity_server_url, username, token, config_path):
syn_cmd.config = json.load(config)
try:
http_client.verbose = "on" == syn_cmd.config["verbose"]
- except:
+ except Exception:
pass
print("Loaded config from %s" % config_path)
- except:
+ except Exception:
pass
# Twisted-specific: Runs the command processor in Twisted's event loop
diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py
index 0e101d2be5..e2534ee584 100644
--- a/contrib/cmdclient/http.py
+++ b/contrib/cmdclient/http.py
@@ -14,14 +14,14 @@
# limitations under the License.
from __future__ import print_function
-from twisted.web.client import Agent, readBody
-from twisted.web.http_headers import Headers
-from twisted.internet import defer, reactor
-
-from pprint import pformat
import json
import urllib
+from pprint import pformat
+
+from twisted.internet import defer, reactor
+from twisted.web.client import Agent, readBody
+from twisted.web.http_headers import Headers
class HttpClient(object):
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index 3bbbcfa1b4..a84ec4ecae 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -28,27 +28,24 @@ Currently assumes the local address is localhost:<port>
"""
-from synapse.federation import ReplicationHandler
-
-from synapse.federation.units import Pdu
-
-from synapse.util import origin_from_ucid
-
-from synapse.app.homeserver import SynapseHomeServer
-
-# from synapse.logging.utils import log_function
-
-from twisted.internet import reactor, defer
-from twisted.python import log
-
import argparse
+import curses.wrapper
import json
import logging
import os
import re
import cursesio
-import curses.wrapper
+
+from twisted.internet import defer, reactor
+from twisted.python import log
+
+from synapse.app.homeserver import SynapseHomeServer
+from synapse.federation import ReplicationHandler
+from synapse.federation.units import Pdu
+from synapse.util import origin_from_ucid
+
+# from synapse.logging.utils import log_function
logger = logging.getLogger("example")
@@ -75,7 +72,7 @@ class InputOutput(object):
"""
try:
- m = re.match("^join (\S+)$", line)
+ m = re.match(r"^join (\S+)$", line)
if m:
# The `sender` wants to join a room.
(room_name,) = m.groups()
@@ -84,7 +81,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
- m = re.match("^invite (\S+) (\S+)$", line)
+ m = re.match(r"^invite (\S+) (\S+)$", line)
if m:
# `sender` wants to invite someone to a room
room_name, invitee = m.groups()
@@ -93,7 +90,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
- m = re.match("^send (\S+) (.*)$", line)
+ m = re.match(r"^send (\S+) (.*)$", line)
if m:
# `sender` wants to message a room
room_name, body = m.groups()
@@ -102,7 +99,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
- m = re.match("^backfill (\S+)$", line)
+ m = re.match(r"^backfill (\S+)$", line)
if m:
# we want to backfill a room
(room_name,) = m.groups()
@@ -201,16 +198,6 @@ class HomeServer(ReplicationHandler):
% (pdu.context, pdu.pdu_type, json.dumps(pdu.content))
)
- # def on_state_change(self, pdu):
- ##self.output.print_line("#%s (state) %s *** %s" %
- ##(pdu.context, pdu.state_key, pdu.pdu_type)
- ##)
-
- # if "joinee" in pdu.content:
- # self._on_join(pdu.context, pdu.content["joinee"])
- # elif "invitee" in pdu.content:
- # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"])
-
def _on_message(self, pdu):
""" We received a message
"""
@@ -314,7 +301,7 @@ class HomeServer(ReplicationHandler):
return self.replication_layer.backfill(dest, room_name, limit)
def _get_room_remote_servers(self, room_name):
- return [i for i in self.joined_rooms.setdefault(room_name).servers]
+ return list(self.joined_rooms.setdefault(room_name).servers)
def _get_or_create_room(self, room_name):
return self.joined_rooms.setdefault(room_name, Room(room_name))
@@ -334,7 +321,7 @@ def main(stdscr):
user = args.user
server_name = origin_from_ucid(user)
- ## Set up logging ##
+ # Set up logging
root_logger = logging.getLogger()
@@ -354,7 +341,7 @@ def main(stdscr):
observer = log.PythonLoggingObserver()
observer.start()
- ## Set up synapse server
+ # Set up synapse server
curses_stdio = cursesio.CursesStdIO(stdscr)
input_output = InputOutput(curses_stdio, user)
@@ -368,16 +355,16 @@ def main(stdscr):
input_output.set_home_server(hs)
- ## Add input_output logger
+ # Add input_output logger
io_logger = IOLoggerHandler(input_output)
io_logger.setFormatter(formatter)
root_logger.addHandler(io_logger)
- ## Start! ##
+ # Start!
try:
port = int(server_name.split(":")[1])
- except:
+ except Exception:
port = 12345
app_hs.get_http_server().start_listening(port)
diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py
index 92736480eb..de33fac1c7 100644
--- a/contrib/graph/graph.py
+++ b/contrib/graph/graph.py
@@ -1,5 +1,13 @@
from __future__ import print_function
+import argparse
+import cgi
+import datetime
+import json
+
+import pydot
+import urllib2
+
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,15 +23,6 @@ from __future__ import print_function
# limitations under the License.
-import sqlite3
-import pydot
-import cgi
-import json
-import datetime
-import argparse
-import urllib2
-
-
def make_name(pdu_id, origin):
return "%s@%s" % (pdu_id, origin)
@@ -33,7 +32,7 @@ def make_graph(pdus, room, filename_prefix):
node_map = {}
origins = set()
- colors = set(("red", "green", "blue", "yellow", "purple"))
+ colors = {"red", "green", "blue", "yellow", "purple"}
for pdu in pdus:
origins.add(pdu.get("origin"))
@@ -49,7 +48,7 @@ def make_graph(pdus, room, filename_prefix):
try:
c = colors.pop()
color_map[o] = c
- except:
+ except Exception:
print("Run out of colours!")
color_map[o] = "black"
diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py
index 4619f0e3c1..0980231e4a 100644
--- a/contrib/graph/graph2.py
+++ b/contrib/graph/graph2.py
@@ -13,12 +13,13 @@
# limitations under the License.
-import sqlite3
-import pydot
+import argparse
import cgi
-import json
import datetime
-import argparse
+import json
+import sqlite3
+
+import pydot
from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
@@ -98,7 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit):
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
- except:
+ except Exception:
end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,))
node_map[prev_id] = end_node
diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py
index 3154638520..91db98e7ef 100644
--- a/contrib/graph/graph3.py
+++ b/contrib/graph/graph3.py
@@ -1,5 +1,15 @@
from __future__ import print_function
+import argparse
+import cgi
+import datetime
+
+import pydot
+import simplejson as json
+
+from synapse.events import FrozenEvent
+from synapse.util.frozenutils import unfreeze
+
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,16 +25,6 @@ from __future__ import print_function
# limitations under the License.
-import pydot
-import cgi
-import simplejson as json
-import datetime
-import argparse
-
-from synapse.events import FrozenEvent
-from synapse.util.frozenutils import unfreeze
-
-
def make_graph(file_name, room_id, file_prefix, limit):
print("Reading lines")
with open(file_name) as f:
@@ -106,7 +106,7 @@ def make_graph(file_name, room_id, file_prefix, limit):
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
- except:
+ except Exception:
end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,))
node_map[prev_id] = end_node
diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py
index 67fb2cd1a7..69aa74bd34 100644
--- a/contrib/jitsimeetbridge/jitsimeetbridge.py
+++ b/contrib/jitsimeetbridge/jitsimeetbridge.py
@@ -12,15 +12,15 @@ npm install jquery jsdom
"""
from __future__ import print_function
-import gevent
-import grequests
-from BeautifulSoup import BeautifulSoup
import json
-import urllib
import subprocess
import time
-# ACCESS_TOKEN="" #
+import gevent
+import grequests
+from BeautifulSoup import BeautifulSoup
+
+ACCESS_TOKEN = ""
MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/"
MYUSERNAME = "@davetest:matrix.org"
diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py
index f57e6e7d25..372dbd9e4f 100755
--- a/contrib/scripts/kick_users.py
+++ b/contrib/scripts/kick_users.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python
from __future__ import print_function
-from argparse import ArgumentParser
+
import json
-import requests
import sys
import urllib
+from argparse import ArgumentParser
+
+import requests
try:
raw_input
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 093e89af6c..8b3a4246a5 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -16,35 +16,31 @@ ARG PYTHON_VERSION=3.7
###
### Stage 0: builder
###
-FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder
+FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
-RUN apk add \
- build-base \
- libffi-dev \
- libjpeg-turbo-dev \
- libwebp-dev \
- libressl-dev \
- libxslt-dev \
- linux-headers \
- postgresql-dev \
- zlib-dev
-# build things which have slow build steps, before we copy synapse, so that
-# the layer can be cached.
-#
-# (we really just care about caching a wheel here, as the "pip install" below
-# will install them again.)
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ libpq-dev \
+ && rm -rf /var/lib/apt/lists/*
+# Build dependencies that are not available as wheels, to speed up rebuilds
RUN pip install --prefix="/install" --no-warn-script-location \
- cryptography \
- msgpack-python \
- pillow \
- pynacl
+ frozendict \
+ jaeger-client \
+ opentracing \
+ prometheus-client \
+ psycopg2 \
+ pycparser \
+ pyrsistent \
+ pyyaml \
+ simplejson \
+ threadloop \
+ thrift
# now install synapse and all of the python deps to /install.
-
COPY synapse /synapse/synapse/
COPY scripts /synapse/scripts/
COPY MANIFEST.in README.rst setup.py synctl /synapse/
@@ -56,20 +52,13 @@ RUN pip install --prefix="/install" --no-warn-script-location \
### Stage 1: runtime
###
-FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
+FROM docker.io/python:${PYTHON_VERSION}-slim
-# xmlsec is required for saml support
-RUN apk add --no-cache --virtual .runtime_deps \
- libffi \
- libjpeg-turbo \
- libwebp \
- libressl \
- libxslt \
- libpq \
- zlib \
- su-exec \
- tzdata \
- xmlsec
+RUN apt-get update && apt-get install -y \
+ libpq5 \
+ xmlsec1 \
+ gosu \
+ && rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
diff --git a/docker/README.md b/docker/README.md
index 8c337149ca..008a9ff708 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -94,6 +94,21 @@ The following environment variables are supported in run mode:
* `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`.
* `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`.
+## Generating an (admin) user
+
+After Synapse is running, you may wish to create a user via `register_new_matrix_user`.
+
+This requires a `registration_shared_secret` to be set in your config file. Synapse
+must be restarted to pick up this change.
+
+You can then call the script:
+
+```
+docker exec -it synapse register_new_matrix_user http://localhost:8008 -c /data/homeserver.yaml --help
+```
+
+Remember to remove the `registration_shared_secret` and restart if you no longer need it.
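+
+For example, to create an admin user non-interactively you could run something
+like the following (a sketch only; check `register_new_matrix_user --help` for
+the exact flags supported by your version, and substitute your own username and
+password):
+
+```
+docker exec -it synapse register_new_matrix_user http://localhost:8008 \
+    -c /data/homeserver.yaml -u someuser -p somepassword -a
+```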
+
## TLS support
The default configuration exposes a single HTTP port: http://localhost:8008. It
diff --git a/docker/start.py b/docker/start.py
index 2a25c9380e..9f08134158 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -120,7 +120,7 @@ def generate_config_from_template(config_dir, config_path, environ, ownership):
if ownership is not None:
subprocess.check_output(["chown", "-R", ownership, "/data"])
- args = ["su-exec", ownership] + args
+ args = ["gosu", ownership] + args
subprocess.check_output(args)
@@ -172,8 +172,8 @@ def run_generate_config(environ, ownership):
# make sure that synapse has perms to write to the data dir.
subprocess.check_output(["chown", ownership, data_dir])
- args = ["su-exec", ownership] + args
- os.execv("/sbin/su-exec", args)
+ args = ["gosu", ownership] + args
+ os.execv("/usr/sbin/gosu", args)
else:
os.execv("/usr/local/bin/python", args)
@@ -189,7 +189,7 @@ def main(args, environ):
ownership = "{}:{}".format(desired_uid, desired_gid)
if ownership is None:
- log("Will not perform chmod/su-exec as UserID already matches request")
+ log("Will not perform chmod/gosu as UserID already matches request")
# In generate mode, generate a configuration and missing keys, then exit
if mode == "generate":
@@ -236,8 +236,8 @@ running with 'migrate_config'. See the README for more details.
args = ["python", "-m", synapse_worker, "--config-path", config_path]
if ownership is not None:
- args = ["su-exec", ownership] + args
- os.execv("/sbin/su-exec", args)
+ args = ["gosu", ownership] + args
+ os.execv("/usr/sbin/gosu", args)
else:
os.execv("/usr/local/bin/python", args)
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 3f26adc16c..15b83e9824 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -319,11 +319,43 @@ Response:
}
```
+# Room Members API
+
+The Room Members admin API allows server admins to get a list of all members of a room.
+
+The response includes the following fields:
+
+* `members` - A list of all the members that are present in the room, represented by their user IDs.
+* `total` - Total number of members in the room.
+
+## Usage
+
+A standard request:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/members
+
+{}
+```
+
+Response:
+
+```
+{
+ "members": [
+ "@foo:matrix.org",
+ "@bar:matrix.org",
+        "@foobar:matrix.org"
+ ],
+ "total": 3
+}
+```
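+
+Like the other admin APIs, the request must be authenticated with the access
+token of a server admin. A minimal `curl` invocation might look something like
+this (assuming Synapse is listening on `localhost:8008`; replace the token and
+room ID as appropriate):
+
+```
+curl --header "Authorization: Bearer <access_token>" \
+    "http://localhost:8008/_synapse/admin/v1/rooms/<room_id>/members"
+```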
+
# Delete Room API
The Delete Room admin API allows server admins to remove rooms from server
and block these rooms.
-It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
+It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
and "[Purge room](purge_room.md)" API.
Shuts down a room. Moves all local users and room aliases automatically to a
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 131990001a..7bfb96eff6 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -38,6 +38,11 @@ the reverse proxy and the homeserver.
server {
listen 443 ssl;
listen [::]:443 ssl;
+
+ # For the federation port
+ listen 8448 ssl default_server;
+ listen [::]:8448 ssl default_server;
+
server_name matrix.example.com;
location /_matrix {
@@ -48,17 +53,6 @@ server {
client_max_body_size 10M;
}
}
-
-server {
- listen 8448 ssl default_server;
- listen [::]:8448 ssl default_server;
- server_name example.com;
-
- location / {
- proxy_pass http://localhost:8008;
- proxy_set_header X-Forwarded-For $remote_addr;
- }
-}
```
**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 5ed44e8a3a..e21864047a 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -102,7 +102,9 @@ pid_file: DATADIR/homeserver.pid
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
-# and sync operations. The default value is -1, means no upper limit.
+# and sync operations. The default value is 100. -1 means no upper limit.
+#
+# Uncomment the following to increase the limit to 5000.
#
#filter_timeline_limit: 5000
@@ -146,7 +148,7 @@ pid_file: DATADIR/homeserver.pid
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
-# compress: set to true to enable HTTP comression for this resource.
+# compress: set to true to enable HTTP compression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
@@ -751,7 +753,7 @@ caches:
#database:
# name: psycopg2
# args:
-# user: synapse
+# user: synapse_user
# password: secretpassword
# database: synapse
# host: localhost
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index e6f4bd1dca..d055cf3287 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -24,7 +24,6 @@ DISTS = (
"debian:sid",
"ubuntu:xenial",
"ubuntu:bionic",
- "ubuntu:eoan",
"ubuntu:focal",
)
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 66b0568858..0647993658 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -11,7 +11,7 @@ if [ $# -ge 1 ]
then
files=$*
else
- files="synapse tests scripts-dev scripts"
+ files="synapse tests scripts-dev scripts contrib synctl"
fi
echo "Linting these locations: $files"
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 0ebffb04a5..b21b8d573d 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -49,6 +49,7 @@ from synapse.storage.data_stores.main.media_repository import (
from synapse.storage.data_stores.main.profile import ProfileStore
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
+ find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
@@ -624,8 +625,10 @@ class Porter(object):
)
)
- # Step 5. Do final post-processing
+ # Step 5. Set up sequences
+ self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
+ await self._setup_user_id_seq()
self.progress.done()
except Exception as e:
@@ -795,6 +798,13 @@ class Porter(object):
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
+ def _setup_user_id_seq(self):
+ def r(txn):
+ next_id = find_max_generated_user_id_localpart(txn) + 1
+ txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
+
+ return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
+
##############################################
# The following is simply UI stuff
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index e90695f026..c1b76d827b 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
-from twisted.internet import address, defer, reactor
+from twisted.internet import address, reactor
import synapse
import synapse.events
@@ -111,6 +111,7 @@ from synapse.rest.client.v1.room import (
RoomSendEventRestServlet,
RoomStateEventRestServlet,
RoomStateRestServlet,
+ RoomTypingRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
@@ -374,9 +375,8 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing()
- @defer.inlineCallbacks
- def notify_from_replication(self, states, stream_id):
- parties = yield get_interested_parties(self.store, states)
+ async def notify_from_replication(self, states, stream_id):
+ parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -386,8 +386,7 @@ class GenericWorkerPresence(BasePresenceHandler):
users=users_to_states.keys(),
)
- @defer.inlineCallbacks
- def process_replication_rows(self, token, rows):
+ async def process_replication_rows(self, token, rows):
states = [
UserPresenceState(
row.user_id,
@@ -405,7 +404,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.user_to_current_state[state.user_id] = state
stream_id = token
- yield self.notify_from_replication(states, stream_id)
+ await self.notify_from_replication(states, stream_id)
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
return [
@@ -451,37 +450,6 @@ class GenericWorkerPresence(BasePresenceHandler):
await self._bump_active_client(user_id=user_id)
-class GenericWorkerTyping(object):
- def __init__(self, hs):
- self._latest_room_serial = 0
- self._reset()
-
- def _reset(self):
- """
- Reset the typing handler's data caches.
- """
- # map room IDs to serial numbers
- self._room_serials = {}
- # map room IDs to sets of users currently typing
- self._room_typing = {}
-
- def process_replication_rows(self, token, rows):
- if self._latest_room_serial > token:
- # The master has gone backwards. To prevent inconsistent data, just
- # clear everything.
- self._reset()
-
- # Set the latest serial token to whatever the server gave us.
- self._latest_room_serial = token
-
- for row in rows:
- self._room_serials[row.room_id] = token
- self._room_typing[row.room_id] = row.user_ids
-
- def get_current_token(self) -> int:
- return self._latest_room_serial
-
-
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
@@ -558,6 +526,7 @@ class GenericWorkerServer(HomeServer):
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
+ RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
@@ -669,9 +638,6 @@ class GenericWorkerServer(HomeServer):
def build_presence_handler(self):
return GenericWorkerPresence(self)
- def build_typing_handler(self):
- return GenericWorkerTyping(self)
-
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 09291d86ad..ec7401f911 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -483,8 +483,7 @@ class SynapseService(service.Service):
_stats_process = []
-@defer.inlineCallbacks
-def phone_stats_home(hs, stats, stats_process=_stats_process):
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
@@ -522,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
- stats["total_users"] = yield hs.get_datastore().count_all_users()
+ stats["total_users"] = await hs.get_datastore().count_all_users()
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
- room_count = yield hs.get_datastore().get_room_count()
+ room_count = await hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
- r30_results = yield hs.get_datastore().count_r30_users()
+ r30_results = await hs.get_datastore().count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@@ -558,7 +557,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
- yield hs.get_proxied_http_client().put_json(
+ await hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index f2830c609d..34a2370e67 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,10 +19,12 @@ import argparse
import errno
import os
from collections import OrderedDict
+from hashlib import sha256
from io import open as io_open
from textwrap import dedent
-from typing import Any, MutableMapping, Optional
+from typing import Any, List, MutableMapping, Optional
+import attr
import yaml
@@ -718,4 +720,36 @@ def find_config_files(search_paths):
return config_files
-__all__ = ["Config", "RootConfig"]
+@attr.s
+class ShardedWorkerHandlingConfig:
+ """Algorithm for choosing which instance is responsible for handling some
+ sharded work.
+
+    For example, the federation senders use this to determine which instance
+    handles sending transactions to a given destination (the destination is
+    used as the `key` below).
+ """
+
+ instances = attr.ib(type=List[str])
+
+ def should_handle(self, instance_name: str, key: str) -> bool:
+ """Whether this instance is responsible for handling the given key.
+ """
+
+ # If multiple instances are not defined we always return true.
+ if not self.instances or len(self.instances) == 1:
+ return True
+
+ # We shard by taking the hash, modulo it by the number of instances and
+ # then checking whether this instance matches the instance at that
+ # index.
+ #
+ # (Technically this introduces some bias and is not entirely uniform,
+ # but since the hash is so large the bias is ridiculously small).
+ dest_hash = sha256(key.encode("utf8")).digest()
+ dest_int = int.from_bytes(dest_hash, byteorder="little")
+ remainder = dest_int % (len(self.instances))
+ return self.instances[remainder] == instance_name
+
+
+__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 9e576060d4..eb911e8f9f 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -137,3 +137,8 @@ class Config:
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...
+
+class ShardedWorkerHandlingConfig:
+ instances: List[str]
+ def __init__(self, instances: List[str]) -> None: ...
+ def should_handle(self, instance_name: str, key: str) -> bool: ...
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 1064c2697b..62bccd9ef5 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -55,7 +55,7 @@ DEFAULT_CONFIG = """\
#database:
# name: psycopg2
# args:
-# user: synapse
+# user: synapse_user
# password: secretpassword
# database: synapse
# host: localhost
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 7782ab4c9d..82ff9664de 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -13,42 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from hashlib import sha256
-from typing import List, Optional
+from typing import Optional
-import attr
from netaddr import IPSet
-from ._base import Config, ConfigError
-
-
-@attr.s
-class ShardedFederationSendingConfig:
- """Algorithm for choosing which federation sender instance is responsible
- for which destionation host.
- """
-
- instances = attr.ib(type=List[str])
-
- def should_send_to(self, instance_name: str, destination: str) -> bool:
- """Whether this instance is responsible for sending transcations for
- the given host.
- """
-
- # If multiple federation senders are not defined we always return true.
- if not self.instances or len(self.instances) == 1:
- return True
-
- # We shard by taking the hash, modulo it by the number of federation
- # senders and then checking whether this instance matches the instance
- # at that index.
- #
- # (Technically this introduces some bias and is not entirely uniform, but
- # since the hash is so large the bias is ridiculously small).
- dest_hash = sha256(destination.encode("utf8")).digest()
- dest_int = int.from_bytes(dest_hash, byteorder="little")
- remainder = dest_int % (len(self.instances))
- return self.instances[remainder] == instance_name
+from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
class FederationConfig(Config):
@@ -61,7 +30,7 @@ class FederationConfig(Config):
self.send_federation = config.get("send_federation", True)
federation_sender_instances = config.get("federation_sender_instances") or []
- self.federation_shard_config = ShardedFederationSendingConfig(
+ self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances
)
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 6f2b3a7faa..a1f3752c8a 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from ._base import Config, ShardedWorkerHandlingConfig
class PushConfig(Config):
@@ -24,6 +24,9 @@ class PushConfig(Config):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
+ pusher_instances = config.get("pusher_instances") or []
+ self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
+
# There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 9f406e471e..35687f427e 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -207,7 +207,7 @@ class ServerConfig(Config):
# errors when attempting to search for messages.
self.enable_search = config.get("enable_search", True)
- self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
+ self.filter_timeline_limit = config.get("filter_timeline_limit", 100)
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
@@ -699,7 +699,9 @@ class ServerConfig(Config):
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
- # and sync operations. The default value is -1, means no upper limit.
+ # and sync operations. The default value is 100. -1 means no upper limit.
+ #
+ # Uncomment the following to increase the limit to 5000.
#
#filter_timeline_limit: 5000
@@ -743,7 +745,7 @@ class ServerConfig(Config):
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
- # compress: set to true to enable HTTP comression for this resource.
+ # compress: set to true to enable HTTP compression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index dbc661630c..2574cd3aa1 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -34,9 +34,11 @@ class WriterLocations:
Attributes:
events: The instance that writes to the event and backfill streams.
+        typing: The instance that writes to the typing stream.
"""
events = attr.ib(default="master", type=str)
+ typing = attr.ib(default="master", type=str)
class WorkerConfig(Config):
@@ -93,16 +95,15 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
- # Check that the configured writer for events also appears in
+ # Check that the configured writer for events and typing also appears in
# `instance_map`.
- if (
- self.writers.events != "master"
- and self.writers.events not in self.instance_map
- ):
- raise ConfigError(
- "Instance %r is configured to write events but does not appear in `instance_map` config."
- % (self.writers.events,)
- )
+ for stream in ("events", "typing"):
+ instance = getattr(self.writers, stream)
+ if instance != "master" and instance not in self.instance_map:
+ raise ConfigError(
+ "Instance %r is configured to write %s but does not appear in `instance_map` config."
+ % (instance, stream)
+ )
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f6b507977f..11f0d34ec8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
+import collections.abc
import re
from typing import Any, Mapping, Union
@@ -424,7 +424,7 @@ def copy_power_levels_contents(
Raises:
TypeError if the input does not look like a valid power levels event content
"""
- if not isinstance(old_power_levels, collections.Mapping):
+ if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels = {}
@@ -434,7 +434,7 @@ def copy_power_levels_contents(
power_levels[k] = v
continue
- if isinstance(v, collections.Mapping):
+ if isinstance(v, collections.abc.Mapping):
power_levels[k] = h = {}
for k1, v1 in v.items():
# we should only have one level of nesting
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a37cc9cb4a..994e6c8d5a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -374,29 +374,26 @@ class FederationClient(FederationBase):
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
- @defer.inlineCallbacks
- def handle_check_result(pdu: EventBase, deferred: Deferred):
+ async def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
- res = yield make_deferred_yieldable(deferred)
+ res = await make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
- res = yield self.store.get_event(
+ res = await self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
- res = yield defer.ensureDeferred(
- self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- room_version=room_version,
- outlier=outlier,
- timeout=10000,
- )
+ res = await self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ room_version=room_version,
+ outlier=outlier,
+ timeout=10000,
)
except SynapseError:
pass
@@ -995,24 +992,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
- @defer.inlineCallbacks
- def get_room_complexity(self, destination, room_id):
+ async def get_room_complexity(
+ self, destination: str, room_id: str
+ ) -> Optional[dict]:
"""
Fetch the complexity of a remote room from another server.
Args:
- destination (str): The remote server
- room_id (str): The room ID to ask about.
+ destination: The remote server
+ room_id: The room ID to ask about.
Returns:
- Deferred[dict] or Deferred[None]: Dict contains the complexity
- metric versions, while None means we could not fetch the complexity.
+            A dict containing the complexity metric versions, or None if we
+            could not fetch the complexity.
"""
try:
- complexity = yield self.transport_layer.get_room_complexity(
+ complexity = await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
- defer.returnValue(complexity)
+ return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
@@ -1029,4 +1027,4 @@ class FederationClient(FederationBase):
# If we don't manage to find it, return None. It's not an error if a
# server doesn't give it to us.
- defer.returnValue(None)
+ return None
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8c53330c49..23625ba995 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,7 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Match,
+ Optional,
+ Tuple,
+ Union,
+)
from canonicaljson import json
from prometheus_client import Counter, Histogram
@@ -56,6 +67,9 @@ from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
@@ -768,11 +782,30 @@ class FederationHandlerRegistry(object):
query type for incoming federation traffic.
"""
- def __init__(self):
- self.edu_handlers = {}
- self.query_handlers = {}
+ def __init__(self, hs: "HomeServer"):
+ self.config = hs.config
+ self.http_client = hs.get_simple_http_client()
+ self.clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
- def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
+ # These are safe to load in monolith mode, but will explode if we try
+ # and use them. However we have guards before we use them to ensure that
+ # we don't route to ourselves, and in monolith mode that will always be
+ # the case.
+ self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
+ self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
+
+ self.edu_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
+ self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+
+ # Map from type to instance name that we should route EDU handling to.
+ self._edu_type_to_instance = {} # type: Dict[str, str]
+
+ def register_edu_handler(
+ self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
+ ):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@@ -809,66 +842,56 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
+ def register_instance_for_edu(self, edu_type: str, instance_name: str):
+ """Register that the EDU handler is on a different instance than master.
+ """
+ self._edu_type_to_instance[edu_type] = instance_name
+
async def on_edu(self, edu_type: str, origin: str, content: dict):
+ if not self.config.use_presence and edu_type == "m.presence":
+ return
+
+ # Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
- if not handler:
- logger.warning("No handler registered for EDU type %s", edu_type)
+ if handler:
+ with start_active_span_from_edu(content, "handle_edu"):
+ try:
+ await handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception:
+ logger.exception("Failed to handle edu %r", edu_type)
return
- with start_active_span_from_edu(content, "handle_edu"):
+ # Check if we can route it somewhere else that isn't us
+ route_to = self._edu_type_to_instance.get(edu_type, "master")
+ if route_to != self._instance_name:
try:
- await handler(origin, content)
+ await self._send_edu(
+ instance_name=route_to,
+ edu_type=edu_type,
+ origin=origin,
+ content=content,
+ )
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
-
- def on_query(self, query_type: str, args: dict) -> defer.Deferred:
- handler = self.query_handlers.get(query_type)
- if not handler:
- logger.warning("No handler registered for query type %s", query_type)
- raise NotFoundError("No handler for Query type '%s'" % (query_type,))
-
- return handler(args)
-
-
-class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
- """A FederationHandlerRegistry for worker processes.
-
- When receiving EDU or queries it will check if an appropriate handler has
- been registered on the worker, if there isn't one then it calls off to the
- master process.
- """
-
- def __init__(self, hs):
- self.config = hs.config
- self.http_client = hs.get_simple_http_client()
- self.clock = hs.get_clock()
-
- self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
- self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
-
- super(ReplicationFederationHandlerRegistry, self).__init__()
-
- async def on_edu(self, edu_type: str, origin: str, content: dict):
- """Overrides FederationHandlerRegistry
- """
- if not self.config.use_presence and edu_type == "m.presence":
return
- handler = self.edu_handlers.get(edu_type)
- if handler:
- return await super(ReplicationFederationHandlerRegistry, self).on_edu(
- edu_type, origin, content
- )
-
- return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
+ # Oh well, let's just log and move on.
+ logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict):
- """Overrides FederationHandlerRegistry
- """
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
- return await self._get_query_client(query_type=query_type, args=args)
+ # Check if we can route it somewhere else that isn't us
+ if self._instance_name == "master":
+ return await self._get_query_client(query_type=query_type, args=args)
+
+ # Uh oh, no handler! Let's raise an exception so the request returns an
+ # error.
+ logger.warning("No handler registered for query type %s", query_type)
+ raise NotFoundError("No handler for Query type '%s'" % (query_type,))
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4b63a0755f..b328a4df09 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -197,7 +197,7 @@ class FederationSender(object):
destinations = {
d
for d in destinations
- if self._federation_shard_config.should_send_to(
+ if self._federation_shard_config.should_handle(
self._instance_name, d
)
}
@@ -335,7 +335,7 @@ class FederationSender(object):
d
for d in domains
if d != self.server_name
- and self._federation_shard_config.should_send_to(self._instance_name, d)
+ and self._federation_shard_config.should_handle(self._instance_name, d)
]
if not domains:
return
@@ -441,7 +441,7 @@ class FederationSender(object):
for destination in destinations:
if destination == self.server_name:
continue
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
continue
@@ -460,7 +460,7 @@ class FederationSender(object):
if destination == self.server_name:
continue
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
continue
@@ -486,7 +486,7 @@ class FederationSender(object):
logger.info("Not sending EDU to ourselves")
return
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@@ -507,7 +507,7 @@ class FederationSender(object):
edu: edu to send
key: clobbering key for this edu
"""
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, edu.destination
):
return
@@ -523,7 +523,7 @@ class FederationSender(object):
logger.warning("Not sending device update to ourselves")
return
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@@ -541,7 +541,7 @@ class FederationSender(object):
logger.warning("Not waking up ourselves")
return
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 6402136e8a..3436741783 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -78,7 +78,7 @@ class PerDestinationQueue(object):
self._federation_shard_config = hs.config.federation.federation_shard_config
self._should_send_on_this_instance = True
- if not self._federation_shard_config.should_send_to(
+ if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
# We don't raise an exception here to avoid taking out any other
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 1478ee03a5..5da69c2c49 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -20,8 +20,6 @@ import logging
import re
from typing import Optional, Tuple, Type
-from twisted.internet.defer import maybeDeferred
-
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -796,12 +794,8 @@ class PublicRoomList(BaseFederationServlet):
# zero is a special value which corresponds to no limit.
limit = None
- data = await maybeDeferred(
- self.handler.get_local_public_room_list,
- limit,
- since_token,
- network_tuple=network_tuple,
- from_federation=True,
+ data = await self.handler.get_local_public_room_list(
+ limit, since_token, network_tuple=network_tuple, from_federation=True
)
return 200, data
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6a4944467a..ba2bf99800 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
import synapse.state
import synapse.storage
import synapse.types
@@ -66,8 +64,7 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
- @defer.inlineCallbacks
- def ratelimit(self, requester, update=True, is_admin_redaction=False):
+ async def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests.
Args:
@@ -99,7 +96,7 @@ class BaseHandler(object):
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
- override = yield self.store.get_ratelimit_for_user(user_id)
+ override = await self.store.get_ratelimit_for_user(user_id)
if override:
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 31346b56c3..db417d60de 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, Optional
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Optional
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
- @defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
+ async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
"""
Retrieve the given user's devices
Args:
- user_id (str):
+ user_id: The user ID to query for devices.
Returns:
- defer.Deferred: list[dict[str, X]]: info on each device
+ info on each device
"""
set_tag("user_id", user_id)
- device_map = yield self.store.get_devices_by_user(user_id)
+ device_map = await self.store.get_devices_by_user(user_id)
- ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
+ ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
- @defer.inlineCallbacks
- def get_device(self, user_id, device_id):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
""" Retrieve the given device
Args:
- user_id (str):
- device_id (str):
+ user_id: The user to get the device from
+ device_id: The device to fetch.
Returns:
- defer.Deferred: dict[str, X]: info on the device
+ info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
- device = yield self.store.get_device(user_id, device_id)
+ device = await self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
- ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
+ ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
set_tag("device", device)
@@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler):
return device
- @measure_func("device.get_user_ids_changed")
@trace
- @defer.inlineCallbacks
- def get_user_ids_changed(self, user_id, from_token):
+ @measure_func("device.get_user_ids_changed")
+ async def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
@@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler):
set_tag("user_id", user_id)
set_tag("from_token", from_token)
- now_room_key = yield self.store.get_room_events_max_id()
+ now_room_key = await self.store.get_room_events_max_id()
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed for users that we share
# rooms with.
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler):
# Always tell the user about their own devices
tracked_users.add(user_id)
- changed = yield self.store.get_users_whose_devices_changed(
+ changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
- member_events = yield self.store.get_membership_changes_for_user(
+ member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
@@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
- current_state_ids = yield self.store.get_current_state_ids(room_id)
+ current_state_ids = await self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
- event_ids = yield self.store.get_forward_extremeties_for_room(
+ event_ids = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
- prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
+ prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def on_federation_query_user_devices(self, user_id):
- stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
- self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ async def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+ master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
@@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- @defer.inlineCallbacks
- def check_device_registered(
+ async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
"""
@@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
- new_device = yield self.store.store_device(
+ new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
return device_id
# if the device id is not specified, we'll autogen one, but loop a few
@@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler):
attempts = 0
while attempts < 5:
device_id = stringutils.random_string(10).upper()
- new_device = yield self.store.store_device(
+ new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@trace
- @defer.inlineCallbacks
- def delete_device(self, user_id, device_id):
+ async def delete_device(self, user_id: str, device_id: str) -> None:
""" Delete the given device
Args:
- user_id (str):
- device_id (str):
-
- Returns:
- defer.Deferred:
+ user_id: The user to delete the device from.
+ device_id: The device to delete.
"""
try:
- yield self.store.delete_device(user_id, device_id)
+ await self.store.delete_device(user_id, device_id)
except errors.StoreError as e:
if e.code == 404:
# no match
@@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
- yield defer.ensureDeferred(
- self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
- )
+ await self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
+ await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
@trace
- @defer.inlineCallbacks
- def delete_all_devices_for_user(self, user_id, except_device_id=None):
+ async def delete_all_devices_for_user(
+ self, user_id: str, except_device_id: Optional[str] = None
+ ) -> None:
"""Delete all of the user's devices
Args:
- user_id (str):
- except_device_id (str|None): optional device id which should not
- be deleted
-
- Returns:
- defer.Deferred:
+ user_id: The user to remove all devices from
+ except_device_id: optional device id which should not be deleted
"""
- device_map = yield self.store.get_devices_by_user(user_id)
+ device_map = await self.store.get_devices_by_user(user_id)
device_ids = list(device_map)
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
- yield self.delete_devices(user_id, device_ids)
+ await self.delete_devices(user_id, device_ids)
- @defer.inlineCallbacks
- def delete_devices(self, user_id, device_ids):
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
""" Delete several devices
Args:
- user_id (str):
- device_ids (List[str]): The list of device IDs to delete
-
- Returns:
- defer.Deferred:
+ user_id: The user to delete devices from.
+ device_ids: The list of device IDs to delete
"""
try:
- yield self.store.delete_devices(user_id, device_ids)
+ await self.store.delete_devices(user_id, device_ids)
except errors.StoreError as e:
if e.code == 404:
# no match
@@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield defer.ensureDeferred(
- self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
- )
+ await self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(
+ await self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
- yield self.notify_device_update(user_id, device_ids)
+ await self.notify_device_update(user_id, device_ids)
- @defer.inlineCallbacks
- def update_device(self, user_id, device_id, content):
+ async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
""" Update the given device
Args:
- user_id (str):
- device_id (str):
- content (dict): body of update request
-
- Returns:
- defer.Deferred:
+ user_id: The user to update devices of.
+ device_id: The device to update.
+ content: body of update request
"""
# Reject a new displayname which is too long.
@@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler):
)
try:
- yield self.store.update_device(
+ await self.store.update_device(
user_id, device_id, new_display_name=new_display_name
)
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
@@ -443,12 +417,15 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
- @defer.inlineCallbacks
- def notify_device_update(self, user_id, device_ids):
+ async def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ if not device_ids:
+ # No changes to notify about, so this is a no-op.
+ return
+
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@@ -459,20 +436,24 @@ class DeviceHandler(DeviceWorkerHandler):
set_tag("target_hosts", hosts)
- position = yield self.store.add_device_change_to_streams(
+ position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
+ if not position:
+ # This should only happen if there are no updates, so we bail.
+ return
+
for device_id in device_ids:
logger.debug(
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids
)
@@ -484,29 +465,29 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender.send_device_messages(host)
log_kv({"message": "sent device update to host", "host": host})
- @defer.inlineCallbacks
- def notify_user_signature_update(self, from_user_id, user_ids):
+ async def notify_user_signature_update(
+ self, from_user_id: str, user_ids: List[str]
+ ) -> None:
"""Notify a user that they have made new signatures of other users.
Args:
- from_user_id (str): the user who made the signature
- user_ids (list[str]): the users IDs that have new signatures
+ from_user_id: the user who made the signature
+ user_ids: the user IDs that have new signatures
"""
- position = yield self.store.add_user_signature_change_to_streams(
+ position = await self.store.add_user_signature_change_to_streams(
from_user_id, user_ids
)
self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
- @defer.inlineCallbacks
- def user_left_room(self, user, room_id):
+ async def user_left_room(self, user, room_id):
user_id = user.to_string()
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+ await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips):
@@ -549,8 +530,7 @@ class DeviceListUpdater(object):
)
@trace
- @defer.inlineCallbacks
- def incoming_device_list_update(self, origin, edu_content):
+ async def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
@@ -583,7 +563,7 @@ class DeviceListUpdater(object):
)
return
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@@ -608,14 +588,13 @@ class DeviceListUpdater(object):
(device_id, stream_id, prev_ids, edu_content)
)
- yield self._handle_device_updates(user_id)
+ await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
- @defer.inlineCallbacks
- def _handle_device_updates(self, user_id):
+ async def _handle_device_updates(self, user_id):
"Actually handle pending updates."
- with (yield self._remote_edu_linearizer.queue(user_id)):
+ with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@@ -632,7 +611,7 @@ class DeviceListUpdater(object):
# Given a list of updates we check if we need to resync. This
# happens if we've missed updates.
- resync = yield self._need_to_do_resync(user_id, pending_updates)
+ resync = await self._need_to_do_resync(user_id, pending_updates)
if logger.isEnabledFor(logging.INFO):
logger.info(
@@ -643,16 +622,16 @@ class DeviceListUpdater(object):
)
if resync:
- yield self.user_device_resync(user_id)
+ await self.user_device_resync(user_id)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
- yield self.store.update_remote_device_list_cache_entry(
+ await self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id
)
- yield self.device_handler.notify_device_update(
+ await self.device_handler.notify_device_update(
user_id, [device_id for device_id, _, _, _ in pending_updates]
)
@@ -660,14 +639,13 @@ class DeviceListUpdater(object):
stream_id for _, stream_id, _, _ in pending_updates
)
- @defer.inlineCallbacks
- def _need_to_do_resync(self, user_id, updates):
+ async def _need_to_do_resync(self, user_id, updates):
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set())
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
+ extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
logger.debug("Current extremity for %r: %r", user_id, extremity)
@@ -692,8 +670,7 @@ class DeviceListUpdater(object):
return False
@trace
- @defer.inlineCallbacks
- def _maybe_retry_device_resync(self):
+ async def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
@@ -705,12 +682,12 @@ class DeviceListUpdater(object):
# we don't send too many requests.
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
- need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
+ need_resync = await self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs.
for user_id in need_resync:
try:
# Try to resync the current user's devices list.
- result = yield self.user_device_resync(
+ result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False,
)
@@ -734,16 +711,17 @@ class DeviceListUpdater(object):
# Allow future calls to retry resyncing out of sync device lists.
self._resync_retry_in_progress = False
- @defer.inlineCallbacks
- def user_device_resync(self, user_id, mark_failed_as_stale=True):
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[dict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
- user_id (str): The user's id whose device_list will be updated.
- mark_failed_as_stale (bool): Whether to mark the user's device list as stale
+ user_id: The user's id whose device_list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
- Deferred[dict]: a dict with device info as under the "devices" in the result of this
+ A dict with device info, as under the "devices" key in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
@@ -752,12 +730,12 @@ class DeviceListUpdater(object):
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
- result = yield self.federation.query_user_devices(origin, user_id)
+ result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
- yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
except (RequestSendFailed, HttpResponseException) as e:
@@ -768,7 +746,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
- yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
@@ -792,7 +770,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
- yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
log_kv({"result": result})
@@ -833,25 +811,24 @@ class DeviceListUpdater(object):
stream_id,
)
- yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+ await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
- cross_signing_device_ids = yield self.process_cross_signing_key_update(
+ cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids
- yield self.device_handler.notify_device_update(user_id, device_ids)
+ await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
self._seen_updates[user_id] = {stream_id}
- defer.returnValue(result)
+ return result
- @defer.inlineCallbacks
- def process_cross_signing_key_update(
+ async def process_cross_signing_key_update(
self,
user_id: str,
master_key: Optional[Dict[str, Any]],
@@ -872,14 +849,14 @@ class DeviceListUpdater(object):
device_ids = []
if master_key:
- yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
_, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 52499c679d..1178af6920 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from collections import Container
+from collections.abc import Container
from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -44,6 +44,7 @@ from synapse.api.errors import (
FederationDeniedError,
FederationError,
HttpResponseException,
+ NotFoundError,
RequestSendFailed,
SynapseError,
)
@@ -1442,10 +1443,20 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- event_content = {"membership": Membership.JOIN}
-
+ # checking the room version will check that we've actually heard of the room
+ # (and return a 404 otherwise)
room_version = await self.store.get_room_version_id(room_id)
+ # now check that we are *still* in the room
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+ if not is_in_room:
+ logger.info(
+ "Got /make_join request for room %s we are no longer in", room_id,
+ )
+ raise NotFoundError("Not an active room on this server")
+
+ event_content = {"membership": Membership.JOIN}
+
builder = self.event_builder_factory.new(
room_version,
{
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da206e1ec1..c47764a4ce 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -488,11 +488,15 @@ class EventCreationHandler(object):
try:
if "displayname" not in content:
- displayname = yield profile.get_displayname(target)
+ displayname = yield defer.ensureDeferred(
+ profile.get_displayname(target)
+ )
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
- avatar_url = yield profile.get_avatar_url(target)
+ avatar_url = yield defer.ensureDeferred(
+ profile.get_avatar_url(target)
+ )
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
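message.py keeps its `@defer.inlineCallbacks` generator here, and a generator cannot `yield` a coroutine object directly, so the now-async profile methods are wrapped with `defer.ensureDeferred`, which converts the coroutine into a Deferred the generator can drive. A small sketch of that bridging pattern with hypothetical names:

from twisted.internet import defer

async def get_displayname(target):
    # Stand-in for the now-async profile handler method.
    return "Alice"

@defer.inlineCallbacks
def build_member_content(target):
    # ensureDeferred turns the coroutine into a Deferred, which is what an
    # inlineCallbacks generator knows how to yield on.
    displayname = yield defer.ensureDeferred(get_displayname(target))
    return {"membership": "join", "displayname": displayname}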
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dd8979e750..acecb9c5db 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,10 +15,8 @@
# limitations under the License.
import logging
-from typing import List
-
-from six.moves import range
+from typing import List
from signedjson.sign import sign_json
from twisted.internet import defer, reactor
@@ -145,16 +143,15 @@ class BaseProfileHandler(BaseHandler):
)
raise
- @defer.inlineCallbacks
- def get_profile(self, user_id):
+ async def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
- displayname = yield self.store.get_profile_displayname(
+ displayname = await self.store.get_profile_displayname(
target_user.localpart
)
- avatar_url = yield self.store.get_profile_avatar_url(
+ avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -165,7 +162,7 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": user_id},
@@ -177,8 +174,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- @defer.inlineCallbacks
- def get_profile_from_cache(self, user_id):
+ async def get_profile_from_cache(self, user_id):
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be correct. Otherwise,
it may be out of date/missing.
@@ -186,10 +182,10 @@ class BaseProfileHandler(BaseHandler):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
- displayname = yield self.store.get_profile_displayname(
+ displayname = await self.store.get_profile_displayname(
target_user.localpart
)
- avatar_url = yield self.store.get_profile_avatar_url(
+ avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -199,14 +195,13 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
- profile = yield self.store.get_from_remote_profile_cache(user_id)
+ profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
- @defer.inlineCallbacks
- def get_displayname(self, target_user):
+ async def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
try:
- displayname = yield self.store.get_profile_displayname(
+ displayname = await self.store.get_profile_displayname(
target_user.localpart
)
except StoreError as e:
@@ -217,7 +212,7 @@ class BaseProfileHandler(BaseHandler):
return displayname
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "displayname"},
@@ -334,11 +329,10 @@ class BaseProfileHandler(BaseHandler):
# start a profile replication push
run_in_background(self._replicate_profiles)
- @defer.inlineCallbacks
- def get_avatar_url(self, target_user):
+ async def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
try:
- avatar_url = yield self.store.get_profile_avatar_url(
+ avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -348,7 +342,7 @@ class BaseProfileHandler(BaseHandler):
return avatar_url
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "avatar_url"},
@@ -455,8 +449,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
return avatar_pieces[-1]
- @defer.inlineCallbacks
- def on_profile_query(self, args):
+ async def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -466,12 +459,12 @@ class BaseProfileHandler(BaseHandler):
response = {}
try:
if just_field is None or just_field == "displayname":
- response["displayname"] = yield self.store.get_profile_displayname(
+ response["displayname"] = await self.store.get_profile_displayname(
user.localpart
)
if just_field is None or just_field == "avatar_url":
- response["avatar_url"] = yield self.store.get_profile_avatar_url(
+ response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart
)
except StoreError as e:
@@ -506,8 +499,7 @@ class BaseProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e)
)
- @defer.inlineCallbacks
- def check_profile_query_allowed(self, target_user, requester=None):
+ async def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
@@ -539,8 +531,8 @@ class BaseProfileHandler(BaseHandler):
return
try:
- requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
- target_user_rooms = yield self.store.get_rooms_for_user(
+ requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
+ target_user_rooms = await self.store.get_rooms_for_user(
target_user.to_string()
)
@@ -573,25 +565,24 @@ class MasterProfileHandler(BaseProfileHandler):
"Update remote profile", self._update_remote_profile_cache
)
- @defer.inlineCallbacks
- def _update_remote_profile_cache(self):
+ async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
- entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+ entries = await self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)
for user_id, displayname, avatar_url in entries:
- is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+ is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
user_id
)
if not is_subscribed:
- yield self.store.maybe_delete_remote_profile_cache(user_id)
+ await self.store.maybe_delete_remote_profile_cache(user_id)
continue
try:
- profile = yield self.federation.make_query(
+ profile = await self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={"user_id": user_id},
@@ -600,7 +591,7 @@ class MasterProfileHandler(BaseProfileHandler):
except Exception:
logger.exception("Failed to get avatar_url")
- yield self.store.update_remote_profile_cache(
+ await self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
@@ -609,4 +600,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
- yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
+ await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 8bc100db42..f922d8a545 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
@@ -129,15 +127,14 @@ class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key)
- to_key = yield self.get_current_key()
+ to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
- events = yield self.store.get_linearized_receipts_for_rooms(
+ events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
@@ -146,8 +143,7 @@ class ReceiptEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
+ async def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
@@ -155,8 +151,8 @@ class ReceiptEventSource(object):
else:
from_key = None
- room_ids = yield self.store.get_rooms_for_user(user.to_string())
- events = yield self.store.get_linearized_receipts_for_rooms(
+ room_ids = await self.store.get_rooms_for_user(user.to_string())
+ events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f223630d43..d00b9dc537 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -28,7 +28,6 @@ from synapse.replication.http.register import (
)
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
-from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -51,14 +50,7 @@ class RegistrationHandler(BaseHandler):
self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
-
- self._next_generated_user_id = None
-
self.macaroon_gen = hs.get_macaroon_generator()
-
- self._generate_user_id_linearizer = Linearizer(
- name="_generate_user_id_linearizer"
- )
self._server_notices_mxid = hs.config.server_notices_mxid
self._show_in_user_directory = self.hs.config.show_users_in_user_directory
@@ -239,7 +231,7 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
- localpart = await self._generate_user_id()
+ localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
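Replacing the in-process `Linearizer` with `store.generate_user_id()` moves the uniqueness guarantee for generated localparts into the database, which is why registration now works with several client_reader workers running at once. A rough, purely illustrative sketch of the idea, assuming a database sequence rather than the actual synapse schema:

class ExampleStore:
    def __init__(self, db):
        self.db = db

    async def generate_user_id(self) -> str:
        # A shared sequence hands out a distinct value per call, no matter
        # which worker process asks, so two workers cannot collide.
        next_id = await self.db.fetch_val("SELECT nextval('user_id_seq')")
        return str(next_id)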
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5e05be6181..5dd7b28391 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,12 +20,10 @@ from typing import Any, Dict, Optional
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@@ -47,7 +45,7 @@ class RoomListHandler(BaseHandler):
hs, "remote_room_list", timeout_ms=30 * 1000
)
- def get_local_public_room_list(
+ async def get_local_public_room_list(
self,
limit=None,
since_token=None,
@@ -72,7 +70,7 @@ class RoomListHandler(BaseHandler):
API
"""
if not self.enable_room_list_search:
- return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
+ return {"chunk": [], "total_room_count_estimate": 0}
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
@@ -87,7 +85,7 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
- return self._get_public_room_list(
+ return await self._get_public_room_list(
limit,
since_token,
search_filter,
@@ -96,7 +94,7 @@ class RoomListHandler(BaseHandler):
)
key = (limit, since_token, network_tuple)
- return self.response_cache.wrap(
+ return await self.response_cache.wrap(
key,
self._get_public_room_list,
limit,
@@ -105,8 +103,7 @@ class RoomListHandler(BaseHandler):
from_federation=from_federation,
)
- @defer.inlineCallbacks
- def _get_public_room_list(
+ async def _get_public_room_list(
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
@@ -145,7 +142,7 @@ class RoomListHandler(BaseHandler):
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
- results = yield self.store.get_largest_public_rooms(
+ results = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
@@ -221,44 +218,44 @@ class RoomListHandler(BaseHandler):
response["chunk"] = results
- response["total_room_count_estimate"] = yield self.store.count_public_rooms(
+ response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple, ignore_non_federatable=from_federation
)
return response
- @cachedInlineCallbacks(num_args=1, cache_context=True)
- def generate_room_entry(
+ @cached(num_args=1, cache_context=True)
+ async def generate_room_entry(
self,
- room_id,
- num_joined_users,
+ room_id: str,
+ num_joined_users: int,
cache_context,
- with_alias=True,
- allow_private=False,
- ):
+ with_alias: bool = True,
+ allow_private: bool = False,
+ ) -> Optional[dict]:
"""Returns the entry for a room
Args:
- room_id (str): The room's ID.
- num_joined_users (int): Number of users in the room.
+ room_id: The room's ID.
+ num_joined_users: Number of users in the room.
cache_context: Information for cached responses.
- with_alias (bool): Whether to return the room's aliases in the result.
- allow_private (bool): Whether invite-only rooms should be shown.
+ with_alias: Whether to return the room's aliases in the result.
+ allow_private: Whether invite-only rooms should be shown.
Returns:
- Deferred[dict|None]: Returns a room entry as a dictionary, or None if this
+ Returns a room entry as a dictionary, or None if this
room was determined not to be shown publicly.
"""
result = {"room_id": room_id, "num_joined_members": num_joined_users}
if with_alias:
- aliases = yield self.store.get_aliases_for_room(
+ aliases = await self.store.get_aliases_for_room(
room_id, on_invalidate=cache_context.invalidate
)
if aliases:
result["aliases"] = aliases
- current_state_ids = yield self.store.get_current_state_ids(
+ current_state_ids = await self.store.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
@@ -266,7 +263,7 @@ class RoomListHandler(BaseHandler):
# We're not in the room, so may as well bail out here.
return result
- event_map = yield self.store.get_events(
+ event_map = await self.store.get_events(
[
event_id
for key, event_id in current_state_ids.items()
@@ -336,8 +333,7 @@ class RoomListHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def get_remote_public_room_list(
+ async def get_remote_public_room_list(
self,
server_name,
limit=None,
@@ -356,7 +352,7 @@ class RoomListHandler(BaseHandler):
# to a locally-filtered search if we must.
try:
- res = yield self._get_remote_list_cached(
+ res = await self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
@@ -381,7 +377,7 @@ class RoomListHandler(BaseHandler):
limit = None
since_token = None
- res = yield self._get_remote_list_cached(
+ res = await self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
@@ -400,7 +396,7 @@ class RoomListHandler(BaseHandler):
return res
- def _get_remote_list_cached(
+ async def _get_remote_list_cached(
self,
server_name,
limit=None,
@@ -412,7 +408,7 @@ class RoomListHandler(BaseHandler):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
- return repl_layer.get_public_rooms(
+ return await repl_layer.get_public_rooms(
server_name,
limit=limit,
since_token=since_token,
@@ -428,7 +424,7 @@ class RoomListHandler(BaseHandler):
include_all_networks,
third_party_instance_id,
)
- return self.remote_response_cache.wrap(
+ return await self.remote_response_cache.wrap(
key,
repl_layer.get_public_rooms,
server_name,
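`generate_room_entry` moves from `@cachedInlineCallbacks` to `@cached` on an async method; the cache still keys on the room ID (`num_args=1`) and still supports invalidation via `cache_context`, the only change being that the wrapped function is a coroutine. A hedged sketch of the descriptor on an async method, with a hypothetical store call:

from synapse.util.caches.descriptors import cached

class ExampleHandler:
    def __init__(self, store):
        self.store = store

    @cached(num_args=1)
    async def get_room_summary(self, room_id):
        # The first call for a given room_id hits the store; later calls are
        # served from the cache until it is invalidated.
        return await self.store.get_room_summary(room_id)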
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 846ddbdc6c..a86ac0150e 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,15 +15,19 @@
import logging
from collections import namedtuple
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, SynapseError
-from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000
FEDERATION_PING_INTERVAL = 40 * 1000
-class TypingHandler(object):
- def __init__(self, hs):
+class FollowerTypingHandler:
+ """A typing handler on a different process than the writer that is updated
+ via replication.
+ """
+
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
- self.auth = hs.get_auth()
- self.is_mine_id = hs.is_mine_id
- self.notifier = hs.get_notifier()
- self.state = hs.get_state_handler()
-
- self.hs = hs
-
self.clock = hs.get_clock()
- self.wheel_timer = WheelTimer(bucket_size=5000)
+ self.is_mine_id = hs.is_mine_id
- self.federation = hs.get_federation_sender()
+ self.federation = None
+ if hs.should_send_federation():
+ self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+ if hs.config.worker.writers.typing != hs.get_instance_name():
+ hs.get_federation_registry().register_instance_for_edu(
+ "m.typing", hs.config.worker.writers.typing,
+ )
- hs.get_distributor().observe("user_left_room", self.user_left_room)
+ # map room IDs to serial numbers
+ self._room_serials = {}
+ # map room IDs to sets of users currently typing
+ self._room_typing = {}
- self._member_typing_until = {} # clock time we expect to stop
self._member_last_federation_poke = {}
-
+ self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
- self._reset()
-
- # caches which room_ids changed at which serials
- self._typing_stream_change_cache = StreamChangeCache(
- "TypingStreamChangeCache", self._latest_room_serial
- )
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
- """
- Reset the typing handler's data caches.
+ """Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
+ self._member_last_federation_poke = {}
+ self.wheel_timer = WheelTimer(bucket_size=5000)
+
def _handle_timeouts(self):
logger.debug("Checking for typing timeouts")
@@ -89,30 +93,140 @@ class TypingHandler(object):
members = set(self.wheel_timer.fetch(now))
for member in members:
- if not self.is_typing(member):
- # Nothing to do if they're no longer typing
- continue
-
- until = self._member_typing_until.get(member, None)
- if not until or until <= now:
- logger.info("Timing out typing for: %s", member.user_id)
- self._stopped_typing(member)
- continue
-
- # Check if we need to resend a keep alive over federation for this
- # user.
- if self.hs.is_mine_id(member.user_id):
- last_fed_poke = self._member_last_federation_poke.get(member, None)
- if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- run_in_background(self._push_remote, member=member, typing=True)
-
- # Add a paranoia timer to ensure that we always have a timer for
- # each person typing.
- self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
+ self._handle_timeout_for_member(now, member)
+
+ def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ if not self.is_typing(member):
+ # Nothing to do if they're no longer typing
+ return
+
+ # Check if we need to resend a keep alive over federation for this
+ # user.
+ if self.federation and self.is_mine_id(member.user_id):
+ last_fed_poke = self._member_last_federation_poke.get(member, None)
+ if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
+ run_as_background_process(
+ "typing._push_remote", self._push_remote, member=member, typing=True
+ )
+
+ # Add a paranoia timer to ensure that we always have a timer for
+ # each person typing.
+ self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
+ async def _push_remote(self, member, typing):
+ if not self.federation:
+ return
+
+ try:
+ users = await self.store.get_users_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
+
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+ )
+
+ for domain in {get_domain_from_id(u) for u in users}:
+ if domain != self.server_name:
+ logger.debug("sending typing update to %s", domain)
+ self.federation.build_and_send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
+
+ def process_replication_rows(
+ self, token: int, rows: List[TypingStream.TypingStreamRow]
+ ):
+ """Should be called whenever we receive updates for typing stream.
+ """
+
+ if self._latest_room_serial > token:
+ # The master has gone backwards. To prevent inconsistent data, just
+ # clear everything.
+ self._reset()
+
+ # Set the latest serial token to whatever the server gave us.
+ self._latest_room_serial = token
+
+ for row in rows:
+ self._room_serials[row.room_id] = token
+
+ prev_typing = set(self._room_typing.get(row.room_id, []))
+ now_typing = set(row.user_ids)
+ self._room_typing[row.room_id] = row.user_ids
+
+ run_as_background_process(
+ "_handle_change_in_typing",
+ self._handle_change_in_typing,
+ row.room_id,
+ prev_typing,
+ now_typing,
+ )
+
+ async def _handle_change_in_typing(
+ self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
+ ):
+ """Process a change in typing of a room from replication, sending EDUs
+ for any local users.
+ """
+ for user_id in now_typing - prev_typing:
+ if self.is_mine_id(user_id):
+ await self._push_remote(RoomMember(room_id, user_id), True)
+
+ for user_id in prev_typing - now_typing:
+ if self.is_mine_id(user_id):
+ await self._push_remote(RoomMember(room_id, user_id), False)
+
+ def get_current_token(self):
+ return self._latest_room_serial
+
+
+class TypingWriterHandler(FollowerTypingHandler):
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ assert hs.config.worker.writers.typing == hs.get_instance_name()
+
+ self.auth = hs.get_auth()
+ self.notifier = hs.get_notifier()
+
+ self.hs = hs
+
+ hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+
+ hs.get_distributor().observe("user_left_room", self.user_left_room)
+
+ self._member_typing_until = {} # clock time we expect to stop
+
+ # caches which room_ids changed at which serials
+ self._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", self._latest_room_serial
+ )
+
+ def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ super()._handle_timeout_for_member(now, member)
+
+ if not self.is_typing(member):
+ # Nothing to do if they're no longer typing
+ return
+
+ until = self._member_typing_until.get(member, None)
+ if not until or until <= now:
+ logger.info("Timing out typing for: %s", member.user_id)
+ self._stopped_typing(member)
+ return
+
async def started_typing(self, target_user, auth_user, room_id, timeout):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
@@ -179,35 +293,11 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- run_in_background(self._push_remote, member, typing)
-
- self._push_update_local(member=member, typing=typing)
-
- async def _push_remote(self, member, typing):
- try:
- users = await self.store.get_users_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
-
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+ run_as_background_process(
+ "typing._push_remote", self._push_remote, member, typing
)
- for domain in {get_domain_from_id(u) for u in users}:
- if domain != self.server_name:
- logger.debug("sending typing update to %s", domain)
- self.federation.build_and_send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
- except Exception:
- logger.exception("Error pushing typing notif to remotes")
+ self._push_update_local(member=member, typing=typing)
async def _recv_edu(self, origin, content):
room_id = content["room_id"]
@@ -304,8 +394,11 @@ class TypingHandler(object):
return rows, current_id, limited
- def get_current_token(self):
- return self._latest_room_serial
+ def process_replication_rows(
+ self, token: int, rows: List[TypingStream.TypingStreamRow]
+ ):
+ # The writing process should never get updates from replication.
+ raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource(object):
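The split leaves exactly one instance, the configured typing writer, handling `m.typing` EDUs and client typing requests, while every other worker keeps a read-only copy updated through `process_replication_rows`. The handler is presumably chosen per instance along these lines (a sketch only, not code from this patch):

def build_typing_handler(hs):
    # Only the instance named in the `writers.typing` config owns the
    # authoritative typing state; everyone else follows the typing stream.
    if hs.config.worker.writers.typing == hs.get_instance_name():
        return TypingWriterHandler(hs)
    return FollowerTypingHandler(hs)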
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c6c0e623c1..2101517575 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -733,37 +733,54 @@ def trace(func=None, opname=None):
_opname = opname if opname else func.__name__
- @wraps(func)
- def _trace_inner(*args, **kwargs):
- if opentracing is None:
- return func(*args, **kwargs)
+ if inspect.iscoroutinefunction(func):
- scope = start_active_span(_opname)
- scope.__enter__()
+ @wraps(func)
+ async def _trace_inner(*args, **kwargs):
+ if opentracing is None:
+ return await func(*args, **kwargs)
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
+ with start_active_span(_opname) as scope:
+ try:
+ return await func(*args, **kwargs)
+ except Exception:
+ scope.span.set_tag(tags.ERROR, True)
+ raise
- def call_back(result):
- scope.__exit__(None, None, None)
- return result
+ else:
+ # The other case here handles both sync functions and those
+ # decorated with defer.inlineCallbacks.
+ @wraps(func)
+ def _trace_inner(*args, **kwargs):
+ if opentracing is None:
+ return func(*args, **kwargs)
- def err_back(result):
- scope.span.set_tag(tags.ERROR, True)
- scope.__exit__(None, None, None)
- return result
+ scope = start_active_span(_opname)
+ scope.__enter__()
+
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
- result.addCallbacks(call_back, err_back)
+ def err_back(result):
+ scope.span.set_tag(tags.ERROR, True)
+ scope.__exit__(None, None, None)
+ return result
- else:
- scope.__exit__(None, None, None)
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ scope.__exit__(None, None, None)
- return result
+ return result
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
return _trace_inner
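Branching on `inspect.iscoroutinefunction` means a traced native coroutine now gets a span that stays open across its `await` points and is closed (or tagged with `tags.ERROR`) when the coroutine actually finishes, which the old Deferred-callback wrapper could not guarantee; that is the "tracing of async functions was broken" fix. Usage is unchanged; a sketch with a hypothetical handler (`load_data` and `process` stand in for real work):

@trace
async def handle_request(request_id):
    # The span covers the whole coroutine, including work after the await,
    # and is marked as an error if this raises.
    data = await load_data(request_id)
    return process(data)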
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 99049bb5d8..fea774e2e5 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -14,9 +14,7 @@
# limitations under the License.
-import inspect
import logging
-import time
from functools import wraps
from inspect import getcallargs
@@ -74,127 +72,3 @@ def log_function(f):
wrapped.__name__ = func_name
return wrapped
-
-
-def time_function(f):
- func_name = f.__name__
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- global _TIME_FUNC_ID
- id = _TIME_FUNC_ID
- _TIME_FUNC_ID += 1
-
- start = time.clock()
-
- try:
- _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
-
- r = f(*args, **kwargs)
- finally:
- end = time.clock()
- _log_debug_as_f(
- f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
- )
-
- return r
-
- return wrapped
-
-
-def trace_function(f):
- func_name = f.__name__
- linenum = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- name = f.__module__
- logger = logging.getLogger(name)
- level = logging.DEBUG
-
- frame = inspect.currentframe()
- if frame is None:
- raise Exception("Can't get current frame!")
-
- s = frame.f_back
-
- to_print = [
- "\t%s:%s %s. Args: args=%s, kwargs=%s"
- % (pathname, linenum, func_name, args, kwargs)
- ]
- while s:
- if True or s.f_globals["__name__"].startswith("synapse"):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- to_print.append(
- "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
- )
-
- s = s.f_back
-
- msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print)
-
- record = logging.LogRecord(
- name=name,
- level=level,
- pathname=pathname,
- lineno=lineno,
- msg=msg,
- args=(),
- exc_info=None,
- )
-
- logger.handle(record)
-
- return f(*args, **kwargs)
-
- wrapped.__name__ = func_name
- return wrapped
-
-
-def get_previous_frames():
-
- frame = inspect.currentframe()
- if frame is None:
- raise Exception("Can't get current frame!")
-
- s = frame.f_back.f_back
- to_return = []
- while s:
- if s.f_globals["__name__"].startswith("synapse"):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- to_return.append(
- "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
- )
-
- s = s.f_back
-
- return ", ".join(to_return)
-
-
-def get_previous_frame(ignore=[]):
- frame = inspect.currentframe()
- if frame is None:
- raise Exception("Can't get current frame!")
- s = frame.f_back.f_back
-
- while s:
- if s.f_globals["__name__"].startswith("synapse"):
- if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- return "{{ %s:%d %s - Args: %s }}" % (
- filename,
- lineno,
- function,
- args_string,
- )
-
- s = s.f_back
-
- return None
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f6a5458681..2456f12f46 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,13 +15,12 @@
# limitations under the License.
import logging
-from collections import defaultdict
-from threading import Lock
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Union
+
+from prometheus_client import Gauge
from twisted.internet import defer
-from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
@@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory
from synapse.util.async_helpers import concurrently_execute
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
+synapse_pushers = Gauge(
+ "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
+)
+
+
class PusherPool:
"""
The pusher pool. This is responsible for dispatching notifications of new events to
@@ -47,36 +55,20 @@ class PusherPool:
Pusher.on_new_receipts are not expected to return deferreds.
"""
- def __init__(self, _hs):
- self.hs = _hs
- self.pusher_factory = PusherFactory(_hs)
- self._should_start_pushers = _hs.config.start_pushers
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.pusher_factory = PusherFactory(hs)
+ self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ # We shard the handling of push notifications by user ID.
+ self._pusher_shard_config = hs.config.push.pusher_shard_config
+ self._instance_name = hs.get_instance_name()
+
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
- # a lock for the pushers dict, since `count_pushers` is called from an different
- # and we otherwise get concurrent modification errors
- self._pushers_lock = Lock()
-
- def count_pushers():
- results = defaultdict(int) # type: Dict[Tuple[str, str], int]
- with self._pushers_lock:
- for pushers in self.pushers.values():
- for pusher in pushers.values():
- k = (type(pusher).__name__, pusher.app_id)
- results[k] += 1
- return results
-
- LaterGauge(
- name="synapse_pushers",
- desc="the number of active pushers",
- labels=["kind", "app_id"],
- caller=count_pushers,
- )
-
def start(self):
"""Starts the pushers off in a background process.
"""
@@ -104,6 +96,7 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
+
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it
@@ -176,6 +169,9 @@ class PusherPool:
access_tokens (Iterable[int]): access token *ids* to remove pushers
for
"""
+ if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+ return
+
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p["access_token"] in tokens:
@@ -237,6 +233,9 @@ class PusherPool:
if not self._should_start_pushers:
return
+ if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+ return
+
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
@@ -275,6 +274,11 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
+ if not self._pusher_shard_config.should_handle(
+ self._instance_name, pusherdict["user_name"]
+ ):
+ return
+
try:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
@@ -298,11 +302,12 @@ class PusherPool:
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
- with self._pushers_lock:
- byuser = self.pushers.setdefault(pusherdict["user_name"], {})
- if appid_pushkey in byuser:
- byuser[appid_pushkey].on_stop()
- byuser[appid_pushkey] = p
+ byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ if appid_pushkey in byuser:
+ byuser[appid_pushkey].on_stop()
+ byuser[appid_pushkey] = p
+
+ synapse_pushers.labels(type(p).__name__, p.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
@@ -330,9 +335,10 @@ class PusherPool:
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
- byuser[appid_pushkey].on_stop()
- with self._pushers_lock:
- del byuser[appid_pushkey]
+ pusher = byuser.pop(appid_pushkey)
+ pusher.on_stop()
+
+ synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index bd394f6b00..a8a16dbc71 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -26,7 +26,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 80f5df60f9..f88e0a2e40 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -14,9 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from prometheus_client import Counter
+from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
@@ -42,8 +54,8 @@ from synapse.replication.tcp.streams import (
EventsStream,
FederationStream,
Stream,
+ TypingStream,
)
-from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -61,6 +73,12 @@ invalidate_cache_counter = Counter(
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+# the type of the entries in _command_queues_by_stream
+_StreamCommandQueue = Deque[
+ Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+]
+
+
class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands
back out to connections.
@@ -96,6 +114,14 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, TypingStream):
+ # Only add TypingStream as a source on the instance in charge of
+ # typing.
+ if hs.config.worker.writers.typing == hs.get_instance_name():
+ self._streams_to_replicate.append(stream)
+
+ continue
+
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
@@ -107,10 +133,6 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
- self._position_linearizer = Linearizer(
- "replication_position", clock=self._clock
- )
-
# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
@@ -122,10 +144,6 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming stream names that are coming from
- # that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
-
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -133,6 +151,32 @@ class ReplicationCommandHandler:
lambda: len(self._connections),
)
+ # When POSITION or RDATA commands arrive, we stick them in a queue and process
+ # them in order in a separate background process.
+
+ # the streams which are currently being processed by _unsafe_process_stream
+ self._processing_streams = set() # type: Set[str]
+
+ # for each stream, a queue of commands that are awaiting processing, and the
+ # connection that they arrived on.
+ self._command_queues_by_stream = {
+ stream_name: _StreamCommandQueue() for stream_name in self._streams
+ }
+
+ # For each connection, the incoming stream names that have received a POSITION
+ # from that connection.
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+
+ LaterGauge(
+ "synapse_replication_tcp_command_queue",
+ "Number of inbound RDATA/POSITION commands queued for processing",
+ ["stream_name"],
+ lambda: {
+ (stream_name,): len(queue)
+ for stream_name, queue in self._command_queues_by_stream.items()
+ },
+ )
+
self._is_master = hs.config.worker_app is None
self._federation_sender = None
@@ -143,6 +187,64 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ async def _add_command_to_stream_queue(
+ self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ ) -> None:
+ """Queue the given received command for processing
+
+ Adds the given command to the per-stream queue, and processes the queue if
+ necessary
+ """
+ stream_name = cmd.stream_name
+ queue = self._command_queues_by_stream.get(stream_name)
+ if queue is None:
+ logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+ return
+
+ # if we're already processing this stream, stick the new command in the
+ # queue, and we're done.
+ if stream_name in self._processing_streams:
+ queue.append((cmd, conn))
+ return
+
+ # otherwise, process the new command.
+
+ # arguably we should start off a new background process here, but nothing
+ # will be too upset if we don't return for ages, so let's save the overhead
+ # and use the existing logcontext.
+
+ self._processing_streams.add(stream_name)
+ try:
+ # might as well skip the queue for this one, since it must be empty
+ assert not queue
+ await self._process_command(cmd, conn, stream_name)
+
+ # now process any other commands that have built up while we were
+ # dealing with that one.
+ while queue:
+ cmd, conn = queue.popleft()
+ try:
+ await self._process_command(cmd, conn, stream_name)
+ except Exception:
+ logger.exception("Failed to handle command %s", cmd)
+
+ finally:
+ self._processing_streams.discard(stream_name)
+
+ async def _process_command(
+ self,
+ cmd: Union[PositionCommand, RdataCommand],
+ conn: AbstractConnection,
+ stream_name: str,
+ ) -> None:
+ if isinstance(cmd, PositionCommand):
+ await self._process_position(stream_name, conn, cmd)
+ elif isinstance(cmd, RdataCommand):
+ await self._process_rdata(stream_name, conn, cmd)
+ else:
+ # This shouldn't be possible
+            raise Exception("Unrecognised command %s in stream queue" % (cmd.NAME,))
+
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
@@ -276,63 +378,71 @@ class ReplicationCommandHandler:
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
- raise
-
- # We linearize here for two reasons:
+ # We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
- with await self._position_linearizer.queue(cmd.stream_name):
- # make sure that we've processed a POSITION for this stream *on this
- # connection*. (A POSITION on another connection is no good, as there
- # is no guarantee that we have seen all the intermediate updates.)
- sbc = self._streams_by_connection.get(conn)
- if not sbc or stream_name not in sbc:
- # Let's drop the row for now, on the assumption we'll receive a
- # `POSITION` soon and we'll catch up correctly then.
- logger.debug(
- "Discarding RDATA for unconnected stream %s -> %s",
- stream_name,
- cmd.token,
- )
- return
-
- if cmd.token is None:
- # I.e. this is part of a batch of updates for this stream (in
- # which case batch until we get an update for the stream with a non
- # None token).
- self._pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self._pending_batches.pop(stream_name, [])
- rows.append(row)
-
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got RDATA for unknown stream: %s", stream_name)
- return
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # Discard this data if this token is earlier than the current
- # position. Note that streams can be reset (in which case you
- # expect an earlier token), but that must be preceded by a
- # POSITION command.
- if cmd.token <= current_token:
- logger.debug(
- "Discarding RDATA from stream %s at position %s before previous position %s",
- stream_name,
- cmd.token,
- current_token,
- )
- else:
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ await self._add_command_to_stream_queue(conn, cmd)
+
+ async def _process_rdata(
+ self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ ) -> None:
+ """Process an RDATA command
+
+ Called after the command has been popped off the queue of inbound commands
+ """
+ try:
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+ except Exception as e:
+ raise Exception(
+ "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
+ ) from e
+
+ # make sure that we've processed a POSITION for this stream *on this
+ # connection*. (A POSITION on another connection is no good, as there
+ # is no guarantee that we have seen all the intermediate updates.)
+ sbc = self._streams_by_connection.get(conn)
+ if not sbc or stream_name not in sbc:
+ # Let's drop the row for now, on the assumption we'll receive a
+ # `POSITION` soon and we'll catch up correctly then.
+ logger.debug(
+ "Discarding RDATA for unconnected stream %s -> %s",
+ stream_name,
+ cmd.token,
+ )
+ return
+
+ if cmd.token is None:
+ # I.e. this is part of a batch of updates for this stream (in
+ # which case batch until we get an update for the stream with a non
+ # None token).
+ self._pending_batches.setdefault(stream_name, []).append(row)
+ return
+
+ # Check if this is the last of a batch of updates
+ rows = self._pending_batches.pop(stream_name, [])
+ rows.append(row)
+
+ stream = self._streams[stream_name]
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -358,67 +468,65 @@ class ReplicationCommandHandler:
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
- stream_name = cmd.stream_name
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got POSITION for unknown stream: %s", stream_name)
- return
+ await self._add_command_to_stream_queue(conn, cmd)
- # We protect catching up with a linearizer in case the replication
- # connection reconnects under us.
- with await self._position_linearizer.queue(stream_name):
- # We're about to go and catch up with the stream, so remove from set
- # of connected streams.
- for streams in self._streams_by_connection.values():
- streams.discard(stream_name)
-
- # We clear the pending batches for the stream as the fetching of the
- # missing updates below will fetch all rows in the batch.
- self._pending_batches.pop(stream_name, [])
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # If the position token matches our current token then we're up to
- # date and there's nothing to do. Otherwise, fetch all updates
- # between then and now.
- missing_updates = cmd.token != current_token
- while missing_updates:
- logger.info(
- "Fetching replication rows for '%s' between %i and %i",
- stream_name,
- current_token,
- cmd.token,
- )
- (
- updates,
- current_token,
- missing_updates,
- ) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
- )
+ async def _process_position(
+ self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ ) -> None:
+ """Process a POSITION command
- # TODO: add some tests for this
+ Called after the command has been popped off the queue of inbound commands
+ """
+ stream = self._streams[stream_name]
- # Some streams return multiple rows with the same stream IDs,
- # which need to be processed in batches.
+ # We're about to go and catch up with the stream, so remove from set
+ # of connected streams.
+ for streams in self._streams_by_connection.values():
+ streams.discard(stream_name)
- for token, rows in _batch_updates(updates):
- await self.on_rdata(
- stream_name,
- cmd.instance_name,
- token,
- [stream.parse_row(row) for row in rows],
- )
+ # We clear the pending batches for the stream as the fetching of the
+ # missing updates below will fetch all rows in the batch.
+ self._pending_batches.pop(stream_name, [])
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
- # We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ # If the position token matches our current token then we're up to
+ # date and there's nothing to do. Otherwise, fetch all updates
+ # between then and now.
+ missing_updates = cmd.token != current_token
+ while missing_updates:
+ logger.info(
+ "Fetching replication rows for '%s' between %i and %i",
+ stream_name,
+ current_token,
+ cmd.token,
)
+ (updates, current_token, missing_updates) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
+ )
+
+ # TODO: add some tests for this
+
+ # Some streams return multiple rows with the same stream IDs,
+ # which need to be processed in batches.
+
+ for token, rows in _batch_updates(updates):
+ await self.on_rdata(
+ stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
+ )
+
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+    # We've now caught up to the position sent to us, so notify the handler.
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
- self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+ self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9076bbe9f1..7a42de3f7d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -294,11 +294,12 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- if hs.config.worker_app is None:
- # on the master, query the typing handler
+ writer_instance = hs.config.worker.writers.typing
+ if writer_instance == hs.get_instance_name():
+ # On the writer, query the typing handler
update_function = typing_handler.get_all_typing_updates
else:
- # Query master process
+ # Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 1c2a4cce7f..16c63ff4ec 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
-from collections import Iterable
+from collections.abc import Iterable
from typing import List, Tuple, Type
import attr
diff --git a/synapse/res/templates/mail-Element.css b/synapse/res/templates/mail-Element.css
new file mode 100644
index 0000000000..6a3e36eda1
--- /dev/null
+++ b/synapse/res/templates/mail-Element.css
@@ -0,0 +1,7 @@
+.header {
+ border-bottom: 4px solid #e4f7ed ! important;
+}
+
+.notif_link a, .footer a {
+ color: #76CFA6 ! important;
+}
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index 6b94d8c367..d87311f659 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -22,6 +22,8 @@
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% elif app_name == "Element" %}
+ <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 019506e5fb..a2dfeb9e9f 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -22,6 +22,8 @@
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% elif app_name == "Element" %}
+ <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index dc373bc5a3..1c88c93f38 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ RoomMembersRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
)
@@ -201,6 +202,7 @@ def register_servlets(hs, http_server):
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
+ RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 544be47060..b8c95d045a 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -231,6 +231,31 @@ class RoomRestServlet(RestServlet):
return 200, ret
+class RoomMembersRestServlet(RestServlet):
+ """
+    Get the list of members of a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, room_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ ret = await self.store.get_room(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ members = await self.store.get_users_in_room(room_id)
+ ret = {"members": members, "total": len(members)}
+
+ return 200, ret
+
+
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c5a84af047..1a3398316d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -818,9 +818,18 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
+ # If we're not on the typing writer instance we should scream if we get
+ # requests.
+ self._is_typing_writer = (
+ hs.config.worker.writers.typing == hs.get_instance_name()
+ )
+
async def on_PUT(self, request, room_id, user_id):
requester = await self.auth.get_user_by_req(request)
+ if not self._is_typing_writer:
+ raise Exception("Got /typing request on instance that is not typing writer")
+
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index bc11b4dda4..b21538766d 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -51,7 +52,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
return patterns
-def set_timeline_upper_limit(filter_json, filter_timeline_limit):
+def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:
+ """
+    Enforces a maximum limit on a timeline query.
+
+    Args:
+        filter_json: The timeline query to modify.
+        filter_timeline_limit: The maximum limit to allow; passing -1 disables
+            enforcing a maximum limit.
+ """
if filter_timeline_limit < 0:
return # no upper limits
timeline = filter_json.get("room", {}).get("timeline", {})
diff --git a/synapse/server.py b/synapse/server.py
index f838a03d71..a34d8149ff 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -44,7 +44,6 @@ from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
FederationServer,
- ReplicationFederationHandlerRegistry,
)
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.sender import FederationSender
@@ -84,7 +83,7 @@ from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
-from synapse.handlers.typing import TypingHandler
+from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
@@ -380,7 +379,10 @@ class HomeServer(object):
return PresenceHandler(self)
def build_typing_handler(self):
- return TypingHandler(self)
+ if self.config.worker.writers.typing == self.get_instance_name():
+ return TypingWriterHandler(self)
+ else:
+ return FollowerTypingHandler(self)
def build_sync_handler(self):
return SyncHandler(self)
@@ -536,10 +538,7 @@ class HomeServer(object):
return RoomMemberMasterHandler(self)
def build_federation_registry(self):
- if self.config.worker_app:
- return ReplicationFederationHandlerRegistry(self)
- else:
- return FederationHandlerRegistry()
+ return FederationHandlerRegistry(self)
def build_server_notices_manager(self):
if self.config.worker_app:
diff --git a/synapse/server.pyi b/synapse/server.pyi
index cd50c721b8..90a673778f 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -148,3 +148,5 @@ class HomeServer(object):
self,
) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
pass
+ def should_send_federation(self) -> bool:
+ pass
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index bfce541ca7..985a042869 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -100,8 +100,8 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # Decode it to a Unicode string before feeding it to json.loads, so we
- # consistenty get a Unicode-containing object out.
+ # Decode it to a Unicode string before feeding it to json.loads, since
+ # Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 59f3394b0a..018826ef69 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -249,7 +249,10 @@ class BackgroundUpdater(object):
retcol="progress_json",
)
- progress = json.loads(progress_json)
+ # Avoid a circular import.
+ from synapse.storage._base import db_to_json
+
+ progress = db_to_json(progress_json)
time_start = self._clock.time_msec()
items_updated = await update_handler(progress, batch_size)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 4b4763c701..932458f651 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -128,7 +128,7 @@ class DataStore(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index b58f04d00d..33cc372dfd 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -77,7 +77,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
global_account_data = {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
rows = self.db.simple_select_list_txn(
@@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore):
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
- room_data[row["account_data_type"]] = json.loads(row["content"])
+ room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
@@ -113,7 +113,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- return json.loads(result)
+ return db_to_json(result)
else:
return None
@@ -137,7 +137,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
return {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db.runInteraction(
@@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore):
allow_none=True,
)
- return json.loads(content_json) if content_json else None
+ return db_to_json(content_json) if content_json else None
return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
@@ -255,7 +255,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: json.loads(row[1]) for row in txn}
+ global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@@ -267,7 +267,7 @@ class AccountDataWorkerStore(SQLBaseStore):
account_data_by_room = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
- room_account_data[row[1]] = json.loads(row[2])
+ room_account_data[row[1]] = db_to_json(row[2])
return global_account_data, account_data_by_room
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 7a1fe8cdd2..56659fed37 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
@@ -303,7 +303,7 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = json.loads(entry["event_ids"])
+ event_ids = db_to_json(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index d313b9705f..da297b31fb 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -65,7 +65,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
@@ -173,7 +173,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
@@ -424,9 +424,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
- txn.execute(sql, (stream_id, stream_id))
-
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 343cf9a2d5..45581a6500 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -577,7 +577,7 @@ class DeviceWorkerStore(SQLBaseStore):
rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
- return {user for row in rows for user in json.loads(row[0])}
+ return {user for row in rows for user in db_to_json(row[0])}
else:
return set()
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 23f4570c4b..615364f018 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
class EndToEndRoomKeyStore(SQLBaseStore):
@@ -148,7 +148,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"forwarded_count": row["forwarded_count"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
- "session_data": json.loads(row["session_data"]),
+ "session_data": db_to_json(row["session_data"]),
}
return sessions
@@ -222,7 +222,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
- "session_data": json.loads(row[5]),
+ "session_data": db_to_json(row[5]),
}
return ret
@@ -319,7 +319,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
)
- result["auth_data"] = json.loads(result["auth_data"])
+ result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 6c3cff82e1..317c07a829 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -366,7 +366,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
- key = json.loads(row["keydata"])
+ key = db_to_json(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index bc9f4f08ea..504babaa7e 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
if actions:
- return json.loads(actions)
+ return db_to_json(actions)
if is_highlight:
return DEFAULT_HIGHLIGHT_ACTION
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 66f01aad84..6f2e0d15cc 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -20,7 +20,6 @@ from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
-from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
@@ -32,7 +31,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
-from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.search import SearchEntry
from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -236,7 +235,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+ results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.db.runInteraction(
@@ -297,7 +296,7 @@ class PersistEventsStore:
if prev_event_id in existing_prevs:
continue
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
to_recursively_check.append(prev_event_id)
existing_prevs.add(prev_event_id)
@@ -583,7 +582,7 @@ class PersistEventsStore:
txn.execute(sql, (room_id, EventTypes.Create, ""))
row = txn.fetchone()
if row:
- event_json = json.loads(row[0])
+ event_json = db_to_json(row[0])
content = event_json.get("content", {})
creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier)
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 62d28f44dc..663c94b24f 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -15,12 +15,10 @@
import logging
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.constants import EventContentFields
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in rows:
try:
event_id = row[1]
- event_json = json.loads(row[2])
+ event_json = db_to_json(row[2])
sender = event_json["sender"]
content = event_json["content"]
@@ -208,7 +206,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in ev_rows:
event_id = row["event_id"]
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -317,7 +315,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
soft_failed = False
if metadata:
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
@@ -358,7 +356,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
graph[event_id] = {prev_event_id}
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
else:
@@ -543,7 +541,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
last_row_event_id = ""
for (event_id, event_json_raw) in results:
try:
- event_json = json.loads(event_json_raw)
+ event_json = db_to_json(event_json_raw)
self.db.simple_insert_many_txn(
txn=txn,
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 01cad7d4fa..e812c67078 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -21,7 +21,6 @@ import threading
from collections import namedtuple
from typing import List, Optional, Tuple
-from canonicaljson import json
from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -40,7 +39,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
@@ -611,8 +610,8 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = json.loads(row["json"])
- internal_metadata = json.loads(row["internal_metadata"])
+ d = db_to_json(row["json"])
+ internal_metadata = db_to_json(row["internal_metadata"])
format_version = row["format_version"]
if format_version is None:
@@ -640,7 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
- logger.error(
+ logger.warning(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 4fb9f9850c..01ff561e1a 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -197,7 +197,7 @@ class GroupServerWorkerStore(SQLBaseStore):
categories = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -221,7 +221,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["category_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -235,7 +235,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_category",
)
- category["profile"] = json.loads(category["profile"])
+ category["profile"] = db_to_json(category["profile"])
return category
@@ -251,7 +251,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["role_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -265,7 +265,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_role",
)
- role["profile"] = json.loads(role["profile"])
+ role["profile"] = db_to_json(role["profile"])
return role
@@ -333,7 +333,7 @@ class GroupServerWorkerStore(SQLBaseStore):
roles = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -462,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
- return json.loads(row["attestation_json"])
+ return db_to_json(row["attestation_json"])
return None
@@ -489,7 +489,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": row[0],
"type": row[1],
"membership": row[2],
- "content": json.loads(row[3]),
+ "content": db_to_json(row[3]),
}
for row in txn
]
@@ -519,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": group_id,
"membership": membership,
"type": gtype,
- "content": json.loads(content_json),
+ "content": db_to_json(content_json),
}
for group_id, membership, gtype, content_json in txn
]
@@ -567,7 +567,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
- (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index f6e78ca590..d181488db7 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -24,7 +24,7 @@ from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
@@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
- rule["conditions"] = json.loads(rawrule["conditions"])
- rule["actions"] = json.loads(rawrule["actions"])
+ rule["conditions"] = db_to_json(rawrule["conditions"])
+ rule["actions"] = db_to_json(rawrule["actions"])
rule["default"] = False
ruleslist.append(rule)
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 5461016240..e18f1ca87c 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -17,11 +17,11 @@
import logging
from typing import Iterable, Iterator, List, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class PusherWorkerStore(SQLBaseStore):
for r in rows:
dataJson = r["data"]
try:
- r["data"] = json.loads(dataJson)
+ r["data"] = db_to_json(dataJson)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 8f5505bd67..1d723f2d34 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.async_helpers import ObservableDeferred
@@ -203,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
- ] = json.loads(row["data"])
+ ] = db_to_json(row["data"])
return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@@ -260,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
- receipt_type[row["user_id"]] = json.loads(row["data"])
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -329,7 +329,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn]
+ updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
limited = False
upper_bound = current_id
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index efb1a4fb4c..e1b6cded65 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config = hs.config
self.clock = hs.get_clock()
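+        # Sequence generator used to allocate numeric user ID localparts
+        # (see generate_user_id below).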
+ self._user_id_seq = build_sequence_generator(
+ database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+ )
+
@cached()
def get_user_by_id(self, user_id):
return self.db.simple_select_one(
@@ -561,39 +567,17 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
- @defer.inlineCallbacks
- def find_next_generated_user_id_localpart(self):
- """
- Gets the localpart of the next generated user ID.
+ async def generate_user_id(self) -> str:
+ """Generate a suitable localpart for a guest user
- Generated user IDs are integers, so we find the largest integer user ID
- already taken and return that plus one.
+ Returns: a (hopefully) free localpart
"""
-
- def _find_next_generated_user_id(txn):
- # We bound between '@0' and '@a' to avoid pulling the entire table
- # out.
- txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
-
- regex = re.compile(r"^@(\d+):")
-
- max_found = 0
-
- for (user_id,) in txn:
- match = regex.search(user_id)
- if match:
- max_found = max(int(match.group(1)), max_found)
-
- return max_found + 1
-
- return (
- (
- yield self.db.runInteraction(
- "find_next_generated_user_id", _find_next_generated_user_id
- )
- )
+ next_id = await self.db.runInteraction(
+ "generate_user_id", self._user_id_seq.get_next_id_txn
)
+ return str(next_id)
+
async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
@@ -1653,3 +1637,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
+
+
+def find_max_generated_user_id_localpart(cur: Cursor) -> int:
+ """
+    Gets the localpart of the largest generated user ID already in use.
+
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that.
+ """
+
+ # We bound between '@0' and '@a' to avoid pulling the entire table
+ # out.
+ cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
+
+ regex = re.compile(r"^@(\d+):")
+
+ max_found = 0
+
+ for (user_id,) in cur:
+ match = regex.search(user_id)
+ if match:
+ max_found = max(int(match.group(1)), max_found)
+ return max_found
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 47f98ba421..93b6380f13 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -28,7 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID
@@ -693,7 +693,7 @@ class RoomWorkerStore(SQLBaseStore):
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
- event_json = json.loads(content_json)
+ event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
@@ -938,7 +938,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
if not row["json"]:
retention_policy = {}
else:
- ev = json.loads(row["json"])
+ ev = db_to_json(row["json"])
retention_policy = ev["content"]
self.db.simple_insert_txn(
@@ -994,7 +994,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
updates = []
for room_id, event_json in txn:
- event_dict = json.loads(event_json)
+ event_dict = db_to_json(event_json)
room_version_id = event_dict.get("content", {}).get(
"room_version", RoomVersions.V1.identifier
)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 44bab65eac..29765890ee 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import Iterable, List, Set
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -27,6 +25,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
LoggingTransaction,
SQLBaseStore,
+ db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
@@ -938,7 +937,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
new file mode 100644
index 0000000000..2011f6bceb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
@@ -0,0 +1,34 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adds a postgres SEQUENCE for generating guest user IDs.
+"""
+
+from synapse.storage.data_stores.main.registration import (
+ find_max_generated_user_id_localpart,
+)
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
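+    # A database SEQUENCE is only used (and only supported) on Postgres; other
+    # engines fall back to a locally-managed counter.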
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ next_id = find_max_generated_user_id_localpart(cur) + 1
+ cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index a8381dc577..d52228297c 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -17,12 +17,10 @@ import logging
import re
from collections import namedtuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -157,7 +155,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 290317fd94..bd7227773a 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import db_to_json
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
@@ -49,7 +50,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = json.loads(row["content"])
+ room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room
return deferred
@@ -180,7 +181,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(
- lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
+ lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
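
For context on the repeated `json.loads` -> `db_to_json` swaps in this and the surrounding files: the helper lives in `synapse.storage._base` and centralises decoding of JSON values stored in the database. A rough sketch of what such a helper has to cope with, assuming the driver may return `memoryview`/`bytes` rather than `str` (an approximation for illustration, not the real implementation):

    import json

    def db_to_json_sketch(db_content):
        # Normalise driver-specific types before decoding.
        if isinstance(db_content, memoryview):
            db_content = db_content.tobytes()
        if isinstance(db_content, (bytes, bytearray)):
            db_content = db_content.decode("utf-8")
        return json.loads(db_content)
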
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index 4c044b1a15..5f1b919748 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from typing import Any, Dict, Optional, Union
import attr
+from canonicaljson import json
from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util import stringutils as stringutils
@@ -118,7 +118,7 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
- result["clientdict"] = json.loads(result["clientdict"])
+ result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
@@ -168,7 +168,7 @@ class UIAuthWorkerStore(SQLBaseStore):
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
- results[row["stage_type"]] = json.loads(row["result"])
+ results[row["stage_type"]] = db_to_json(row["result"])
return results
@@ -224,7 +224,7 @@ class UIAuthWorkerStore(SQLBaseStore):
)
# Update it and add it back to the database.
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
serverdict[key] = value
self.db.simple_update_one_txn(
@@ -254,7 +254,7 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session_data",
)
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
return serverdict.get(key, default)
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 5db9f20135..128c09a2cf 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"*stateGroupMembersCache*", 500000,
)
+ def get_max_state_group_txn(txn: Cursor):
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ return txn.fetchone()[0]
+
+ self._state_group_seq_gen = build_sequence_generator(
+ self.database_engine, get_max_state_group_txn, "state_group_id_seq"
+ )
+
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
- state_group = self.database_engine.get_next_state_group_id(txn)
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
self.db.simple_insert_txn(
txn,
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4bd3..908cbc79e3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def lock_table(self, txn, table: str) -> None:
...
- @abc.abstractmethod
- def get_next_state_group_id(self, txn) -> int:
- """Returns an int that can be used as a new state_group ID
- """
- ...
-
@property
@abc.abstractmethod
def server_version(self) -> str:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a31588080d..ff39281f85 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- txn.execute("SELECT nextval('state_group_id_seq')")
- return txn.fetchone()[0]
-
@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a949442..8a0f8c89d1 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def lock_table(self, txn, table):
return
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- # We do application locking here since if we're using sqlite then
- # we are a single process synapse.
- with self._current_state_group_id_lock:
- if self._current_state_group_id is None:
- txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- self._current_state_group_id = txn.fetchone()[0]
-
- self._current_state_group_id += 1
- return self._current_state_group_id
-
@property
def server_version(self):
"""Gets a string giving the server version. For example: '3.22.0'
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..787cebfbec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
from typing_extensions import Deque
from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
class IdGenerator(object):
@@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
):
self._db = db
self._instance_name = instance_name
- self._sequence_name = sequence_name
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
@@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
@@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
return current_positions
def _load_next_id_txn(self, txn):
- txn.execute("SELECT nextval(?)", (self._sequence_name,))
- (next_id,) = txn.fetchone()
- return next_id
+ return self._sequence_gen.get_next_id_txn(txn)
async def get_next(self):
"""
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..63dfea4220
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+ """A class which generates a unique sequence of integers"""
+
+ @abc.abstractmethod
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ """Gets the next ID in the sequence"""
+ ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+ def __init__(self, sequence_name: str):
+ self._sequence_name = sequence_name
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses local locking
+
+ This only works reliably if there are no other worker processes generating IDs at
+ the same time.
+ """
+
+ def __init__(self, get_first_callback: GetFirstCallbackType):
+ """
+ Args:
+ get_first_callback: a callback which is called on the first call to
+            get_next_id_txn; should return the current maximum id
+ """
+ # the callback. this is cleared after it is called, so that it can be GCed.
+ self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+
+ # The current max value, or None if we haven't looked in the DB yet.
+ self._current_max_id = None # type: Optional[int]
+ self._lock = threading.Lock()
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ # We do application locking here since if we're using sqlite then
+ # we are a single process synapse.
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ self._current_max_id += 1
+ return self._current_max_id
+
+
+def build_sequence_generator(
+ database_engine: BaseDatabaseEngine,
+ get_first_callback: GetFirstCallbackType,
+ sequence_name: str,
+) -> SequenceGenerator:
+ """Get the best impl of SequenceGenerator available
+
+ This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+ sqlite.
+
+ Args:
+ database_engine: the database engine we are connected to
+        get_first_callback: a callback which returns the current maximum ID.
+            Only used if we're on sqlite.
+ sequence_name: the name of a postgres sequence to use.
+ """
+ if isinstance(database_engine, PostgresEngine):
+ return PostgresSequenceGenerator(sequence_name)
+ else:
+ return LocalSequenceGenerator(get_first_callback)
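
A short usage sketch of the new module, mirroring the `state/store.py` hunk above: a store builds one generator at start-up and asks it for IDs inside a database transaction. The table, sequence, and function names here are illustrative only:

    from synapse.storage.util.sequence import build_sequence_generator

    def make_id_allocator(database_engine):
        # The seed callback is only consulted on SQLite; Postgres reads the
        # named sequence instead.
        def get_current_max_txn(txn):
            txn.execute("SELECT COALESCE(max(id), 0) FROM some_table")
            return txn.fetchone()[0]

        return build_sequence_generator(
            database_engine, get_current_max_txn, "some_table_id_seq"
        )

    def insert_row_txn(seq_gen, txn, value):
        new_id = seq_gen.get_next_id_txn(txn)  # allocated within this transaction
        txn.execute("INSERT INTO some_table (id, value) VALUES (?, ?)", (new_id, value))
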
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index da20523b70..22a857a306 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import inspect
import logging
from twisted.internet import defer
+from twisted.internet.defer import Deferred, fail, succeed
+from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -79,6 +81,28 @@ class Distributor(object):
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
+def maybeAwaitableDeferred(f, *args, **kw):
+ """
+    Invoke a function which may return a Deferred, an awaitable, or a plain value.
+
+ This is a modified version of twisted.internet.defer.maybeDeferred.
+ """
+ try:
+ result = f(*args, **kw)
+ except Exception:
+ return fail(failure.Failure(captureVars=Deferred.debug))
+
+ if isinstance(result, Deferred):
+ return result
+ # Handle the additional case of an awaitable being returned.
+ elif inspect.isawaitable(result):
+ return defer.ensureDeferred(result)
+ elif isinstance(result, failure.Failure):
+ return fail(result)
+ else:
+ return succeed(result)
+
+
class Signal(object):
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@@ -122,7 +146,7 @@ class Signal(object):
),
)
- return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+ return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
deferreds = [run_in_background(do, o) for o in self.observers]
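
The point of `maybeAwaitableDeferred` is that signal observers may now be plain functions, Deferred-returning functions, or coroutine functions, and `Signal.fire` treats them uniformly. A small illustration of plain values and coroutines being normalised to Deferreds (assumes no reactor interaction is needed, so the coroutine completes immediately):

    from synapse.util.distributor import maybeAwaitableDeferred

    def plain_observer(x):
        return x * 2

    async def coroutine_observer(x):
        return x * 3

    d1 = maybeAwaitableDeferred(plain_observer, 2)      # Deferred firing with 4
    d2 = maybeAwaitableDeferred(coroutine_observer, 2)  # Deferred firing with 6
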
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 08c86e92b8..2e2b40a426 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -17,7 +17,7 @@ import itertools
import random
import re
import string
-from collections import Iterable
+from collections.abc import Iterable
from synapse.api.errors import Codes, SynapseError
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f6574..6aa322bf3a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- res = self.handler.get_device(user1, "abc")
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
- res = self.handler.update_device("user_id", "unknown_device_id", update)
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.update_device("user_id", "unknown_device_id", update),
+ synapse.api.errors.NotFoundError,
)
def _record_users(self):
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index cdd093ffa8..210ddcbb88 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = None
try:
- yield self.hs.get_device_handler().check_device_registered(
- user_id=local_user,
- device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
- initial_device_display_name="new display name",
+ yield defer.ensureDeferred(
+ self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ )
)
except errors.SynapseError as e:
res = e.code
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a1f4bde347..42a236aa58 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -70,7 +70,9 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
- displayname = yield self.handler.get_displayname(self.frank)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.frank)
+ )
self.assertEquals("Frank", displayname)
@@ -138,7 +140,9 @@ class ProfileTestCase(unittest.TestCase):
{"displayname": "Alice"}
)
- displayname = yield self.handler.get_displayname(self.alice)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.alice)
+ )
self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
@@ -152,8 +156,10 @@ class ProfileTestCase(unittest.TestCase):
def test_incoming_fed_query(self):
yield self.store.set_profile_displayname("caroline", "Caroline", 1)
- response = yield self.query_handlers["profile"](
- {"user_id": "@caroline:test", "field": "displayname"}
+ response = yield defer.ensureDeferred(
+ self.query_handlers["profile"](
+ {"user_id": "@caroline:test", "field": "displayname"}
+ )
)
self.assertEquals({"displayname": "Caroline"}, response)
@@ -163,8 +169,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png", 1
)
-
- avatar_url = yield self.handler.get_avatar_url(self.frank)
+ avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
new file mode 100644
index 0000000000..2bdc6edbb1
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
@@ -0,0 +1,193 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks pusher sharding works
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["start_pushers"] = False
+ return conf
+
+ def _create_pusher_and_send_msg(self, localpart):
+ # Create a user that will get push notifications
+ user_id = self.register_user(localpart, "pass")
+ access_token = self.login(localpart, "pass")
+
+ # Register a pusher
+ user_dict = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_dict["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "https://push.example.com/push"},
+ )
+ )
+
+ self.pump()
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ return event_id
+
+ def test_send_push_single_worker(self):
+        """Test that pushes are sent when using a single pusher worker.
+ """
+ http_client_mock = Mock(spec_set=["post_json_get_json"])
+ http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {"start_pushers": True},
+ proxied_http_client=http_client_mock,
+ )
+
+ event_id = self._create_pusher_and_send_msg("user")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ def test_send_push_multiple_workers(self):
+        """Test that pushes are sent when using sharded pusher workers.
+ """
+ http_client_mock1 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher1",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock1,
+ )
+
+ http_client_mock2 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher2",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock2,
+ )
+
+ # We choose a user name that we know should go to pusher1.
+ event_id = self._create_pusher_and_send_msg("user2")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_called_once()
+ http_client_mock2.post_json_get_json.assert_not_called()
+ self.assertEqual(
+ http_client_mock1.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ http_client_mock1.post_json_get_json.reset_mock()
+ http_client_mock2.post_json_get_json.reset_mock()
+
+ # Now we choose a user name that we know should go to pusher2.
+ event_id = self._create_pusher_and_send_msg("user4")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_not_called()
+ http_client_mock2.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock2.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index a80537c4fc..946f06d151 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1136,6 +1136,52 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ def test_room_members(self):
+ """Test that room members can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+        # Have another user join both rooms
+ user_2 = self.register_user("bar", "pass")
+ user_tok_2 = self.login("bar", "pass")
+ self.helper.join(room_id_1, user_2, tok=user_tok_2)
+ self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+        # Have another user join the second room
+ user_3 = self.register_user("foobar", "pass")
+ user_tok_3 = self.login("foobar", "pass")
+ self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 89dcc58b99..87a16d7d7a 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
- store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+ store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
@@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
- return_value={
- "user_id": remote_user_id,
- "stream_id": 1,
- "devices": [],
- "master_key": {
+ return_value=succeed(
+ {
"user_id": remote_user_id,
- "usage": ["master"],
- "keys": {"ed25519:" + remote_master_key: remote_master_key},
- },
- "self_signing_key": {
- "user_id": remote_user_id,
- "usage": ["self_signing"],
- "keys": {
- "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
},
- },
- }
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
)
# Resync the device list.
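
Similarly, mocked storage and federation methods whose real counterparts are now awaited must hand back something awaitable, so the canned values are wrapped in `succeed()` to produce an already-fired Deferred. A minimal illustration (the store object here is a stand-in):

    from mock import Mock
    from twisted.internet.defer import succeed

    store = Mock()
    # Callers now `await store.get_rooms_for_user(...)`, so the mock must return
    # a Deferred rather than a bare list.
    store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
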
diff --git a/tox.ini b/tox.ini
index e5aef3c062..8a506a3818 100644
--- a/tox.ini
+++ b/tox.ini
@@ -127,7 +127,7 @@ deps =
black==19.10b0
commands =
python -m black --check --diff .
- /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
+ /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
{toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort]
|