diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 1ead0d0030..8939fda67d 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -5,3 +5,4 @@
* [ ] Pull request is based on the develop branch
* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
+* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style))
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index a71a4a696b..df81f6e54f 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -58,10 +58,29 @@ All Matrix projects have a well-defined code-style - and sometimes we've even
got as far as documenting it... For instance, synapse's code style doc lives
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md.
+To facilitate meeting these criteria you can run ``scripts-dev/lint.sh``
+locally. Since this runs the tools listed in the above document, you'll need
+python 3.6 and to install each tool. **Note that the script does not just
+test/check, but also reformats code, so you may wish to ensure any new code is
+committed first**. By default this script checks all files and can take some
+time; if you alter only certain files, you might wish to specify paths as
+arguments to reduce the run-time.
+
Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
+Before doing a commit, ensure the changes you've made don't produce
+linting errors. You can do this by running the linters as follows. Ensure to
+commit any files that were corrected.
+
+::
+ # Install the dependencies
+ pip install -U black flake8 isort
+
+ # Run the linter script
+ ./scripts-dev/lint.sh
+
Changelog
~~~~~~~~~
diff --git a/changelog.d/5727.feature b/changelog.d/5727.feature
new file mode 100644
index 0000000000..819bebf2d7
--- /dev/null
+++ b/changelog.d/5727.feature
@@ -0,0 +1 @@
+Add federation support for cross-signing.
diff --git a/changelog.d/6164.doc b/changelog.d/6164.doc
new file mode 100644
index 0000000000..f9395b02b3
--- /dev/null
+++ b/changelog.d/6164.doc
@@ -0,0 +1 @@
+Contributor documentation now mentions script to run linters.
diff --git a/changelog.d/6232.bugfix b/changelog.d/6232.bugfix
new file mode 100644
index 0000000000..12718ba934
--- /dev/null
+++ b/changelog.d/6232.bugfix
@@ -0,0 +1 @@
+Remove a room from a server's public rooms list on room upgrade.
\ No newline at end of file
diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature
new file mode 100644
index 0000000000..d225ac33b6
--- /dev/null
+++ b/changelog.d/6238.feature
@@ -0,0 +1 @@
+Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
diff --git a/changelog.d/6254.bugfix b/changelog.d/6254.bugfix
new file mode 100644
index 0000000000..3181484b88
--- /dev/null
+++ b/changelog.d/6254.bugfix
@@ -0,0 +1 @@
+Make notification of cross-signing signatures work with workers.
diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc
new file mode 100644
index 0000000000..d4190730b2
--- /dev/null
+++ b/changelog.d/6298.misc
@@ -0,0 +1 @@
+Refactor EventContext for clarity.
\ No newline at end of file
diff --git a/changelog.d/6301.feature b/changelog.d/6301.feature
new file mode 100644
index 0000000000..78a187a1dc
--- /dev/null
+++ b/changelog.d/6301.feature
@@ -0,0 +1 @@
+Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)).
diff --git a/changelog.d/6304.misc b/changelog.d/6304.misc
new file mode 100644
index 0000000000..20372b4f7c
--- /dev/null
+++ b/changelog.d/6304.misc
@@ -0,0 +1 @@
+Update the version of black used to 19.10b0.
diff --git a/changelog.d/6305.misc b/changelog.d/6305.misc
new file mode 100644
index 0000000000..f047fc3062
--- /dev/null
+++ b/changelog.d/6305.misc
@@ -0,0 +1 @@
+Add some documentation about worker replication.
diff --git a/changelog.d/6306.bugfix b/changelog.d/6306.bugfix
new file mode 100644
index 0000000000..c7dcbcdce8
--- /dev/null
+++ b/changelog.d/6306.bugfix
@@ -0,0 +1 @@
+Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash.
diff --git a/changelog.d/6312.misc b/changelog.d/6312.misc
new file mode 100644
index 0000000000..55e3e1654d
--- /dev/null
+++ b/changelog.d/6312.misc
@@ -0,0 +1 @@
+Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only.
diff --git a/changelog.d/6313.bugfix b/changelog.d/6313.bugfix
new file mode 100644
index 0000000000..f4d4a97f00
--- /dev/null
+++ b/changelog.d/6313.bugfix
@@ -0,0 +1 @@
+Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0.
diff --git a/changelog.d/6314.misc b/changelog.d/6314.misc
new file mode 100644
index 0000000000..2369760272
--- /dev/null
+++ b/changelog.d/6314.misc
@@ -0,0 +1 @@
+Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated.
\ No newline at end of file
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index 6b22400a60..3bbbcfa1b4 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -78,7 +78,7 @@ class InputOutput(object):
m = re.match("^join (\S+)$", line)
if m:
# The `sender` wants to join a room.
- room_name, = m.groups()
+ (room_name,) = m.groups()
self.print_line("%s joining %s" % (self.user, room_name))
self.server.join_room(room_name, self.user, self.user)
# self.print_line("OK.")
@@ -105,7 +105,7 @@ class InputOutput(object):
m = re.match("^backfill (\S+)$", line)
if m:
# we want to backfill a room
- room_name, = m.groups()
+ (room_name,) = m.groups()
self.print_line("backfill %s" % room_name)
self.server.backfill(room_name)
return
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index e099d8a87b..ba9e874d07 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -199,7 +199,20 @@ client (C):
#### REPLICATE (C)
- Asks the server to replicate a given stream
+Asks the server to replicate a given stream. The syntax is:
+
+```
+ REPLICATE <stream_name> <token>
+```
+
+Where `<token>` may be either:
+ * a numeric stream_id to stream updates since (exclusive)
+ * `NOW` to stream all subsequent updates.
+
+The `<stream_name>` is the name of a replication stream to subscribe
+to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
+of streams). It can also be `ALL` to subscribe to all known streams,
+in which case the `<token>` must be set to `NOW`.
#### USER_SYNC (C)
diff --git a/mypy.ini b/mypy.ini
index ffadaddc0b..1d77c0ecc8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,8 +1,11 @@
[mypy]
-namespace_packages=True
-plugins=mypy_zope:plugin
-follow_imports=skip
-mypy_path=stubs
+namespace_packages = True
+plugins = mypy_zope:plugin
+follow_imports = normal
+check_untyped_defs = True
+show_error_codes = True
+show_traceback = True
+mypy_path = stubs
[mypy-zope]
ignore_missing_imports = True
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 02a2ca39e5..34c4854e11 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -7,7 +7,15 @@
set -e
-isort -y -rc synapse tests scripts-dev scripts
-flake8 synapse tests
-python3 -m black synapse tests scripts-dev scripts
+if [ $# -ge 1 ]
+then
+ files=$*
+else
+ files="synapse tests scripts-dev scripts"
+fi
+
+echo "Linting these locations: $files"
+isort -y -rc $files
+flake8 $files
+python3 -m black $files
./scripts-dev/config-lint.sh
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 312196675e..49c4b85054 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -138,3 +138,10 @@ class LimitBlockingTypes(object):
MONTHLY_ACTIVE_USER = "monthly_active_user"
HS_DISABLED = "hs_disabled"
+
+
+class EventContentFields(object):
+ """Fields found in events' content, regardless of type."""
+
+ # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
+ LABELS = "org.matrix.labels"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 9f06556bd2..bec13f08d8 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -20,6 +20,7 @@ from jsonschema import FormatChecker
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import RoomID, UserID
@@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
"contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"},
+ # Include or exclude events with the provided labels.
+ # cf https://github.com/matrix-org/matrix-doc/pull/2326
+ "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
+ "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
},
}
@@ -259,6 +264,9 @@ class Filter(object):
self.contains_url = self.filter_json.get("contains_url", None)
+ self.labels = self.filter_json.get("org.matrix.labels", None)
+ self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+
def filters_all_types(self):
return "*" in self.not_types
@@ -282,6 +290,7 @@ class Filter(object):
room_id = None
ev_type = "m.presence"
contains_url = False
+ labels = []
else:
sender = event.get("sender", None)
if not sender:
@@ -300,10 +309,11 @@ class Filter(object):
content = event.get("content", {})
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
+ labels = content.get(EventContentFields.LABELS, [])
- return self.check_fields(room_id, sender, ev_type, contains_url)
+ return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, contains_url):
+ def check_fields(self, room_id, sender, event_type, labels, contains_url):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -313,6 +323,7 @@ class Filter(object):
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(event_type, v),
+ "labels": lambda v: v in labels,
}
for name, match_func in literal_keys.items():
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8997c1f9e7..8d28076d92 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -565,7 +565,7 @@ def run(hs):
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
)
try:
- yield hs.get_simple_http_client().put_json(
+ yield hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 33b3579425..aea3985a5f 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -94,7 +94,9 @@ class ApplicationService(object):
ip_range_whitelist=None,
):
self.token = token
- self.url = url
+ self.url = (
+ url.rstrip("/") if isinstance(url, str) else None
+ ) # url must not end with a slash
self.hs_token = hs_token
self.sender = sender
self.server_name = hostname
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 2d2c1e54df..75bb904718 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -234,8 +234,8 @@ def setup_logging(
# make sure that the first thing we log is a thing we can grep backwards
# for
- logging.warn("***** STARTING SERVER *****")
- logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
+ logging.warning("***** STARTING SERVER *****")
+ logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
return logger
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 27cd8a63ff..a269de5482 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -37,9 +37,6 @@ class EventContext:
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
(type, state_key) -> event_id. ``None`` for an outlier.
- prev_state_events (?): XXX: is this ever set to anything other than
- the empty list?
-
app_service: FIXME
_current_state_ids (dict[(str, str), str]|None):
@@ -51,36 +48,16 @@ class EventContext:
The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
-
- _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
- been calculated. None if we haven't started calculating yet
-
- _event_type (str): The type of the event the context is associated with.
- Only set when state has not been fetched yet.
-
- _event_state_key (str|None): The state_key of the event the context is
- associated with. Only set when state has not been fetched yet.
-
- _prev_state_id (str|None): If the event associated with the context is
- a state event, then `_prev_state_id` is the event_id of the state
- that was replaced.
- Only set when state has not been fetched yet.
"""
state_group = attr.ib(default=None)
rejected = attr.ib(default=False)
prev_group = attr.ib(default=None)
delta_ids = attr.ib(default=None)
- prev_state_events = attr.ib(default=attr.Factory(list))
app_service = attr.ib(default=None)
- _current_state_ids = attr.ib(default=None)
_prev_state_ids = attr.ib(default=None)
- _prev_state_id = attr.ib(default=None)
-
- _event_type = attr.ib(default=None)
- _event_state_key = attr.ib(default=None)
- _fetching_state_deferred = attr.ib(default=None)
+ _current_state_ids = attr.ib(default=None)
@staticmethod
def with_state(
@@ -90,7 +67,6 @@ class EventContext:
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
state_group=state_group,
- fetching_state_deferred=defer.succeed(None),
prev_group=prev_group,
delta_ids=delta_ids,
)
@@ -125,7 +101,6 @@ class EventContext:
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}
@@ -141,7 +116,7 @@ class EventContext:
Returns:
EventContext
"""
- context = EventContext(
+ context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id=input["prev_state_id"],
@@ -151,7 +126,6 @@ class EventContext:
prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
- prev_state_events=input["prev_state_events"],
)
app_service_id = input["app_service_id"]
@@ -170,14 +144,7 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
-
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
+ yield self._ensure_fetched(store)
return self._current_state_ids
@defer.inlineCallbacks
@@ -190,14 +157,7 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
-
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
+ yield self._ensure_fetched(store)
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -211,6 +171,44 @@ class EventContext:
return self._current_state_ids
+ def _ensure_fetched(self, store):
+ return defer.succeed(None)
+
+
+@attr.s(slots=True)
+class _AsyncEventContextImpl(EventContext):
+ """
+ An implementation of EventContext which fetches _current_state_ids and
+ _prev_state_ids from the database on demand.
+
+ Attributes:
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _event_type (str): The type of the event the context is associated with.
+
+ _event_state_key (str): The state_key of the event the context is
+ associated with.
+
+ _prev_state_id (str|None): If the event associated with the context is
+ a state event, then `_prev_state_id` is the event_id of the state
+ that was replaced.
+ """
+
+ _prev_state_id = attr.ib(default=None)
+ _event_type = attr.ib(default=None)
+ _event_state_key = attr.ib(default=None)
+ _fetching_state_deferred = attr.ib(default=None)
+
+ def _ensure_fetched(self, store):
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store
+ )
+
+ return make_deferred_yieldable(self._fetching_state_deferred)
+
@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
@@ -228,27 +226,6 @@ class EventContext:
else:
self._prev_state_ids = self._current_state_ids
- @defer.inlineCallbacks
- def update_state(
- self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
- ):
- """Replace the state in the context
- """
-
- # We need to make sure we wait for any ongoing fetching of state
- # to complete so that the updated state doesn't get clobbered
- if self._fetching_state_deferred:
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
- self.state_group = state_group
- self._prev_state_ids = prev_state_ids
- self.prev_group = prev_group
- self._current_state_ids = current_state_ids
- self.delta_ids = delta_ids
-
- # We need to ensure that that we've marked as having fetched the state
- self._fetching_state_deferred = defer.succeed(None)
-
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 595706d01a..545d719652 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -555,7 +555,7 @@ class FederationClient(FederationBase):
Note that this does not append any events to any graphs.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations (Iterable[str]): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index cc75c39476..a5b36b1827 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -192,15 +192,16 @@ class PerDestinationQueue(object):
# We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2
- device_update_edus, dev_list_id = (
- yield self._get_device_update_edus(limit)
+ device_update_edus, dev_list_id = yield self._get_device_update_edus(
+ limit
)
limit -= len(device_update_edus)
- to_device_edus, device_stream_id = (
- yield self._get_to_device_message_edus(limit)
- )
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = yield self._get_to_device_message_edus(limit)
pending_edus = device_update_edus + to_device_edus
@@ -359,20 +360,20 @@ class PerDestinationQueue(object):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
- now_stream_id, results = yield self._store.get_devices_by_remote(
+ now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type=edu_type,
content=content,
)
- for content in results
+ for (edu_type, content) in results
]
- assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+ assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
return (edus, now_stream_id)
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 38bc67191c..2d7e6df6e4 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -38,9 +38,10 @@ class AccountDataEventSource(object):
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
)
- account_data, room_account_data = (
- yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
- )
+ (
+ account_data,
+ room_account_data,
+ ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3e9b298154..fe62f78e67 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -73,7 +73,10 @@ class ApplicationServicesHandler(object):
try:
limit = 100
while True:
- upper_bound, events = yield self.store.get_new_events_for_appservice(
+ (
+ upper_bound,
+ events,
+ ) = yield self.store.get_new_events_for_appservice(
self.current_max, limit
)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 16b4617f68..26ef5e150c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ user_id, "self_signing"
+ )
+
+ return {
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ "master_key": master_key,
+ "self_signing_key": self_signing_key,
+ }
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 526379c6f7..c4632f8984 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
ignore_backoff=True,
)
except CodeMessageException as e:
- logging.warn("Error retrieving alias")
+ logging.warning("Error retrieving alias")
if e.code == 404:
result = None
else:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5ea54f60be..f09a0b73c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- hs.get_federation_registry().register_query_handler(
+ federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -119,9 +130,10 @@ class E2eKeysHandler(object):
else:
query_list.append((user_id, None))
- user_ids_not_in_cache, remote_results = (
- yield self.store.get_user_devices_from_cache(query_list)
- )
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = yield self.store.get_user_devices_from_cache(query_list)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
for device_id, device in iteritems(devices):
@@ -207,13 +219,15 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
@@ -251,7 +265,7 @@ class E2eKeysHandler(object):
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master|self_signing|user_signing) -> user_id -> key
+ (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -343,7 +357,16 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- return {"device_keys": res}
+ ret = {"device_keys": res}
+
+ # add in the cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, None
+ )
+
+ ret.update(cross_signing_keys)
+
+ return ret
@trace
@defer.inlineCallbacks
@@ -688,17 +711,21 @@ class E2eKeysHandler(object):
try:
# get our self-signing key to verify the signatures
- _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "self_signing"
- )
+ (
+ _,
+ self_signing_key_id,
+ self_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so
# that we can check if a signature that was sent is a signature of
# the master key or of a device
- master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "master"
- )
+ (
+ master_key,
+ _,
+ master_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
@@ -838,9 +865,11 @@ class E2eKeysHandler(object):
try:
# get our user-signing key to verify the signatures
- user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "user_signing"
- )
+ (
+ user_signing_key,
+ user_signing_key_id,
+ user_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
@@ -859,7 +888,11 @@ class E2eKeysHandler(object):
try:
# get the target user's master key, to make sure it matches
# what was sent
- master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key(
+ (
+ master_key,
+ master_key_id,
+ _,
+ ) = yield self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id
)
@@ -1047,3 +1080,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+ """Handles incoming signing key updates from federation and updates the DB"""
+
+ def __init__(self, hs, e2e_keys_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.e2e_keys_handler = e2e_keys_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have them about to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="signing_key_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_signing_key_update(self, origin, edu_content):
+ """Called on incoming signing key update from federation. Responsible for
+ parsing the EDU and adding to pending updates list.
+
+ Args:
+ origin (string): the server that sent the EDU
+ edu_content (dict): the contents of the EDU
+ """
+
+ user_id = edu_content.pop("user_id")
+ master_key = edu_content.pop("master_key", None)
+ self_signing_key = edu_content.pop("self_signing_key", None)
+
+ if get_domain_from_id(user_id) != origin:
+ logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+ return
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We don't share any rooms with this user. Ignore update, as we
+ # probably won't get any further updates.
+ return
+
+ self._pending_updates.setdefault(user_id, []).append(
+ (master_key, self_signing_key)
+ )
+
+ yield self._handle_signing_key_updates(user_id)
+
+ @defer.inlineCallbacks
+ def _handle_signing_key_updates(self, user_id):
+ """Actually handle pending updates.
+
+ Args:
+ user_id (string): the user whose updates we are processing
+ """
+
+ device_handler = self.e2e_keys_handler.device_handler
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ device_ids = []
+
+ logger.info("pending updates: %r", pending_updates)
+
+ for master_key, self_signing_key in pending_updates:
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "master", master_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(
+ self_signing_key
+ )
+ device_ids.append(verify_key.version)
+
+ yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d2d9f8c26a..8cafcfdab0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
+from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
make_deferred_yieldable,
@@ -352,10 +353,11 @@ class FederationHandler(BaseHandler):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state, got_auth_chain = (
- yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
+ (
+ remote_state,
+ got_auth_chain,
+ ) = yield self.federation_client.get_state_for_room(
+ origin, room_id, p
)
# we want the state *after* p; get_state_for_room returns the
@@ -1105,7 +1107,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the
- server `target_host`.
+ servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
event that we can fill out and sign. This is then sent to the
@@ -1114,6 +1116,15 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we
have finished processing the join.
+
+ Args:
+ target_hosts (Iterable[str]): List of servers to attempt to join the room with.
+
+ room_id (str): The ID of the room to join.
+
+ joinee (str): The User ID of the joining user.
+
+ content (dict): The event content to use for the join event.
"""
logger.debug("Joining %s to %s", joinee, room_id)
@@ -1173,6 +1184,22 @@ class FederationHandler(BaseHandler):
yield self._persist_auth_tree(origin, auth_chain, state, event)
+ # Check whether this room is the result of an upgrade of a room we already know
+ # about. If so, migrate over user information
+ predecessor = yield self.store.get_room_predecessor(room_id)
+ if not predecessor:
+ return
+ old_room_id = predecessor["room_id"]
+ logger.debug(
+ "Found predecessor for %s during remote join: %s", room_id, old_room_id
+ )
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ yield member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, room_id
+ )
+
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -1845,14 +1872,7 @@ class FederationHandler(BaseHandler):
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
- try:
- yield self.do_auth(origin, event, context, auth_events=auth_events)
- except AuthError as e:
- logger.warning(
- "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
- )
-
- context.rejected = RejectedReason.AUTH_ERROR
+ context = yield self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
yield self._check_for_soft_fail(event, state, backfilled)
@@ -2021,12 +2041,12 @@ class FederationHandler(BaseHandler):
Also NB that this function adds entries to it.
Returns:
- defer.Deferred[None]
+ defer.Deferred[EventContext]: updated context object
"""
room_version = yield self.store.get_room_version(event.room_id)
try:
- yield self._update_auth_events_and_context_for_auth(
+ context = yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
except Exception:
@@ -2044,7 +2064,9 @@ class FederationHandler(BaseHandler):
event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
- raise e
+ context.rejected = RejectedReason.AUTH_ERROR
+
+ return context
@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
@@ -2068,7 +2090,7 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
- defer.Deferred[None]
+ defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())
@@ -2107,7 +2129,7 @@ class FederationHandler(BaseHandler):
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
- return
+ return context
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
@@ -2148,7 +2170,7 @@ class FederationHandler(BaseHandler):
if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
- return
+ return context
# FIXME: Assumes we have and stored all the state for all the
# prev_events
@@ -2157,7 +2179,7 @@ class FederationHandler(BaseHandler):
)
if not different_auth:
- return
+ return context
logger.info(
"auth_events refers to events which are not in our calculated auth "
@@ -2204,10 +2226,12 @@ class FederationHandler(BaseHandler):
auth_events.update(new_state)
- yield self._update_context_for_auth_events(
+ context = yield self._update_context_for_auth_events(
event, context, auth_events, event_key
)
+ return context
+
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
@@ -2216,14 +2240,16 @@ class FederationHandler(BaseHandler):
Args:
event (Event): The event we're handling the context for
- context (synapse.events.snapshot.EventContext): event context
- to be updated
+ context (synapse.events.snapshot.EventContext): initial event context
auth_events (dict[(str, str)->str]): Events to update in the event
context.
event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context.
+
+ Returns:
+ Deferred[EventContext]: new event context
"""
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@@ -2248,7 +2274,7 @@ class FederationHandler(BaseHandler):
current_state_ids=current_state_ids,
)
- yield context.update_state(
+ return EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
@@ -2441,6 +2467,8 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
@@ -2501,6 +2529,7 @@ class FederationHandler(BaseHandler):
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 49c9e031f9..81dce96f4b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -128,8 +128,8 @@ class InitialSyncHandler(BaseHandler):
tags_by_room = yield self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(user_id)
+ account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+ user_id
)
public_room_ids = yield self.store.get_public_room_ids()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0d546d2487..d682dc2b7a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -76,9 +76,10 @@ class MessageHandler(object):
Raises:
SynapseError if something went wrong.
"""
- membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if membership == Membership.JOIN:
data = yield self.state.get_current_state(room_id, event_type, state_key)
@@ -153,9 +154,10 @@ class MessageHandler(object):
% (user_id, room_id, at_token),
)
else:
- membership, membership_event_id = (
- yield self.auth.check_in_room_or_world_readable(room_id, user_id)
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6d8b04efe3..260a4351ca 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -214,9 +214,10 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)):
- membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ (
+ membership,
+ member_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -299,10 +300,8 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = (
- yield self._event_serializer.serialize_events(
- state, time_now, as_client_event=as_client_event
- )
+ chunk["state"] = yield self._event_serializer.serialize_events(
+ state, time_now, as_client_event=as_client_event
)
return chunk
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 53410f120b..cff6b0d375 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = (
- yield room_member_handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+ room_alias
)
room_id = room_id.to_string()
else:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 650bd28abb..e92b2eafd5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id,
new_version, # args for _upgrade_room
)
+
return ret
@defer.inlineCallbacks
@@ -147,21 +148,22 @@ class RoomCreationHandler(BaseHandler):
# we create and auth the tombstone event before properly creating the new
# room, to check our user has perms in the old room.
- tombstone_event, tombstone_context = (
- yield self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- },
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = yield self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
},
- token_id=requester.access_token_id,
- )
+ },
+ token_id=requester.access_token_id,
)
old_room_version = yield self.store.get_room_version(old_room_id)
yield self.auth.check_from_context(
@@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
requester, old_room_id, new_room_id, old_room_state
)
- # and finally, shut down the PLs in the old room, and update them in the new
+ # Copy over user push rules, tags and migrate room directory state
+ yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, new_room_id
+ )
+
+ # finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 380e2fad5e..06d09c2947 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,10 +203,6 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- # Copy over user state if we're joining an upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -455,11 +451,6 @@ class RoomMemberHandler(object):
requester, remote_room_hosts, room_id, target, content
)
- # Copy over user state if this is a join on an remote upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
-
return remote_join_response
elif effective_membership_state == Membership.LEAVE:
@@ -498,36 +489,72 @@ class RoomMemberHandler(object):
return res
@defer.inlineCallbacks
- def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
- """Copy user-specific information when they join a new room if that new room is the
+ def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ """Upon our server becoming aware of an upgraded room, either by upgrading a room
+ ourselves or joining one, we can transfer over information from the previous room.
+
+ Copies user state (tags/push rules) for every local user that was in the old room, as
+ well as migrating the room directory state.
+
+ Args:
+ old_room_id (str): The ID of the old room
+
+ room_id (str): The ID of the new room
+
+ Returns:
+ Deferred
+ """
+ # Find all local users that were in the old room and copy over each user's state
+ users = yield self.store.get_users_in_room(old_room_id)
+ yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+
+ # Add new room to the room directory if the old room was there
+ # Remove old room from the room directory
+ old_room = yield self.store.get_room(old_room_id)
+ if old_room and old_room["is_public"]:
+ yield self.store.set_room_is_public(old_room_id, False)
+ yield self.store.set_room_is_public(room_id, True)
+
+ @defer.inlineCallbacks
+ def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ """Copy user-specific information when they join a new room when that new room is the
result of a room upgrade
Args:
- new_room_id (str): The ID of the room the user is joining
- user_id (str): The ID of the user
+ old_room_id (str): The ID of upgraded room
+ new_room_id (str): The ID of the new room
+ user_ids (Iterable[str]): User IDs to copy state for
Returns:
Deferred
"""
- # Check if the new room is an upgraded room
- predecessor = yield self.store.get_room_predecessor(new_room_id)
- if not predecessor:
- return
logger.debug(
- "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+ "Copying over room tags and push rules from %s to %s for users %s",
+ old_room_id,
new_room_id,
- predecessor,
+ user_ids,
)
- # It is an upgraded room. Copy over old tags
- yield self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], new_room_id, user_id
- )
- # Copy over push rules
- yield self.store.copy_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], new_room_id, user_id
- )
+ for user_id in user_ids:
+ try:
+ # It is an upgraded room. Copy over old tags
+ yield self.copy_room_tags_and_direct_to_room(
+ old_room_id, new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+ except Exception:
+ logger.exception(
+ "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+ "Skipping...",
+ old_room_id,
+ new_room_id,
+ user_id,
+ )
+ continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
@@ -759,22 +786,25 @@ class RoomMemberHandler(object):
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
- token, public_keys, fallback_public_key, display_name = (
- yield self.identity_handler.ask_id_server_for_third_party_invite(
- requester=requester,
- id_server=id_server,
- medium=medium,
- address=address,
- room_id=room_id,
- inviter_user_id=user.to_string(),
- room_alias=canonical_room_alias,
- room_avatar_url=room_avatar_url,
- room_join_rules=room_join_rules,
- room_name=room_name,
- inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url,
- id_access_token=id_access_token,
- )
+ (
+ token,
+ public_keys,
+ fallback_public_key,
+ display_name,
+ ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+ requester=requester,
+ id_server=id_server,
+ medium=medium,
+ address=address,
+ room_id=room_id,
+ inviter_user_id=user.to_string(),
+ room_alias=canonical_room_alias,
+ room_avatar_url=room_avatar_url,
+ room_join_rules=room_join_rules,
+ room_name=room_name,
+ inviter_display_name=inviter_display_name,
+ inviter_avatar_url=inviter_avatar_url,
+ id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index f4d8a60774..56ed262a1f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -396,15 +396,11 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = (
- yield self._event_serializer.serialize_events(
- context["events_before"], time_now
- )
+ context["events_before"] = yield self._event_serializer.serialize_events(
+ context["events_before"], time_now
)
- context["events_after"] = (
- yield self._event_serializer.serialize_events(
- context["events_after"], time_now
- )
+ context["events_after"] = yield self._event_serializer.serialize_events(
+ context["events_after"], time_now
)
state_results = {}
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 26bc276692..7f7d56390e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler):
user_deltas = {}
# Then count deltas for total_events and total_event_bytes.
- room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
+ (
+ room_count,
+ user_count,
+ ) = yield self.store.get_changes_room_total_events_and_bytes(
self.pos, max_pos
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 43a082dcda..b536d410e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1206,10 +1206,11 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- account_data, account_data_by_room = (
- yield self.store.get_updated_account_data_for_user(
- user_id, since_token.account_data_key
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = yield self.store.get_updated_account_data_for_user(
+ user_id, since_token.account_data_key
)
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
@@ -1221,9 +1222,10 @@ class SyncHandler(object):
sync_config.user
)
else:
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(sync_config.user.to_string())
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 29aa1e5aaf..8363d887a9 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
super().__init__(hs)
self._enabled = bool(hs.config.recaptcha_private_key)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 2df5b383b5..d4c285445e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
@@ -183,7 +184,15 @@ class SimpleHttpClient(object):
using HTTP in Matrix
"""
- def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ hs,
+ treq_args={},
+ ip_whitelist=None,
+ ip_blacklist=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
"""
Args:
hs (synapse.server.HomeServer)
@@ -192,6 +201,8 @@ class SimpleHttpClient(object):
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -236,11 +247,13 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(
+ self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
+ http_proxy=http_proxy,
+ https_proxy=https_proxy,
)
if self._ip_blacklist:
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 0000000000..be7b2ceb8e
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+ pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+ """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+ Wraps an existing HostnameEndpoint for the proxy.
+
+ When we get the connect() request from the connection pool (via the TLS wrapper),
+ we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+ CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+ in.
+
+ Args:
+ reactor: the Twisted reactor to use for the connection
+ proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+ proxy
+ host (bytes): hostname that we want to CONNECT to
+ port (int): port that we want to connect to
+ """
+
+ def __init__(self, reactor, proxy_endpoint, host, port):
+ self._reactor = reactor
+ self._proxy_endpoint = proxy_endpoint
+ self._host = host
+ self._port = port
+
+ def __repr__(self):
+ return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
+
+ def connect(self, protocolFactory):
+ f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ d = self._proxy_endpoint.connect(f)
+ # once the tcp socket connects successfully, we need to wait for the
+ # CONNECT to complete.
+ d.addCallback(lambda conn: f.on_connection)
+ return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+ """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+ Once the CONNECT completes, invokes the original ClientFactory to build the
+ HTTP Protocol object and run the rest of the connection.
+
+ Args:
+ dst_host (bytes): hostname that we want to CONNECT to
+ dst_port (int): port that we want to connect to
+ wrapped_factory (protocol.ClientFactory): The original Factory
+ """
+
+ def __init__(self, dst_host, dst_port, wrapped_factory):
+ self.dst_host = dst_host
+ self.dst_port = dst_port
+ self.wrapped_factory = wrapped_factory
+ self.on_connection = defer.Deferred()
+
+ def startedConnecting(self, connector):
+ return self.wrapped_factory.startedConnecting(connector)
+
+ def buildProtocol(self, addr):
+ wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+ return HTTPConnectProtocol(
+ self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ )
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.debug("Connection to proxy failed: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.debug("Connection to proxy lost: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+ """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+ Args:
+ host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ to put in the CONNECT request
+
+ port (int): The original HTTP(s) port to put in the CONNECT request
+
+ wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+ HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+ connected_deferred (Deferred): a Deferred which will be callbacked with
+ wrapped_protocol when the CONNECT completes
+ """
+
+ def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ self.host = host
+ self.port = port
+ self.wrapped_protocol = wrapped_protocol
+ self.connected_deferred = connected_deferred
+ self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+ def connectionMade(self):
+ self.http_setup_client.makeConnection(self.transport)
+
+ def connectionLost(self, reason=connectionDone):
+ if self.wrapped_protocol.connected:
+ self.wrapped_protocol.connectionLost(reason)
+
+ self.http_setup_client.connectionLost(reason)
+
+ if not self.connected_deferred.called:
+ self.connected_deferred.errback(reason)
+
+ def proxyConnected(self, _):
+ self.wrapped_protocol.makeConnection(self.transport)
+
+ self.connected_deferred.callback(self.wrapped_protocol)
+
+ # Get any pending data from the http buf and forward it to the original protocol
+ buf = self.http_setup_client.clearLineBuffer()
+ if buf:
+ self.wrapped_protocol.dataReceived(buf)
+
+ def dataReceived(self, data):
+ # if we've set up the HTTP protocol, we can send the data there
+ if self.wrapped_protocol.connected:
+ return self.wrapped_protocol.dataReceived(data)
+
+ # otherwise, we must still be setting up the connection: send the data to the
+ # setup client
+ return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+ """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+ Args:
+ host (bytes): The hostname to send in the CONNECT message
+ port (int): The port to send in the CONNECT message
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ self.on_connected = defer.Deferred()
+
+ def connectionMade(self):
+ logger.debug("Connected to proxy, sending CONNECT")
+ self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+ self.endHeaders()
+
+ def handleStatus(self, version, status, message):
+ logger.debug("Got Status: %s %s %s", status, message, version)
+ if status != b"200":
+ raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+ def handleEndHeaders(self):
+ logger.debug("End Headers")
+ self.on_connected.callback(None)
+
+ def handleResponse(self, body):
+ pass
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 0000000000..332da02a8d
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+ """An Agent implementation which will use an HTTP proxy if one was requested
+
+ Args:
+ reactor: twisted reactor to place outgoing
+ connections.
+
+ contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ verification parameters of OpenSSL. The default is to use a
+ `BrowserLikePolicyForHTTPS`, so unless you have special
+ requirements you can leave this as-is.
+
+ connectTimeout (float): The amount of time that this Agent will wait
+ for the peer to accept a connection.
+
+ bindAddress (bytes): The local address for client sockets to bind to.
+
+ pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ non-persistent pool instance will be created.
+ """
+
+ def __init__(
+ self,
+ reactor,
+ contextFactory=BrowserLikePolicyForHTTPS(),
+ connectTimeout=None,
+ bindAddress=None,
+ pool=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
+ _AgentBase.__init__(self, reactor, pool)
+
+ self._endpoint_kwargs = {}
+ if connectTimeout is not None:
+ self._endpoint_kwargs["timeout"] = connectTimeout
+ if bindAddress is not None:
+ self._endpoint_kwargs["bindAddress"] = bindAddress
+
+ self.http_proxy_endpoint = _http_proxy_endpoint(
+ http_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self.https_proxy_endpoint = _http_proxy_endpoint(
+ https_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self._policy_for_https = contextFactory
+ self._reactor = reactor
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Issue a request to the server indicated by the given uri.
+
+ Supports `http` and `https` schemes.
+
+ An existing connection from the connection pool may be used or a new one may be
+ created.
+
+ See also: twisted.web.iweb.IAgent.request
+
+ Args:
+ method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+ uri (bytes): The location of the resource to request.
+
+ headers (Headers|None): Extra headers to send with the request
+
+ bodyProducer (IBodyProducer|None): An object which can generate bytes to
+ make up the body of this request (for example, the properly encoded
+ contents of a file for a file upload). Or, None if the request is to
+ have no body.
+
+ Returns:
+ Deferred[IResponse]: completes when the header of the response has
+ been received (regardless of the response status code).
+ """
+ uri = uri.strip()
+ if not _VALID_URI.match(uri):
+ raise ValueError("Invalid URI {!r}".format(uri))
+
+ parsed_uri = URI.fromBytes(uri)
+ pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ request_path = parsed_uri.originForm
+
+ if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ pool_key = ("http-proxy", self.http_proxy_endpoint)
+ endpoint = self.http_proxy_endpoint
+ request_path = uri
+ elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self.https_proxy_endpoint,
+ parsed_uri.host,
+ parsed_uri.port,
+ )
+ else:
+ # not using a proxy
+ endpoint = HostnameEndpoint(
+ self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+ )
+
+ logger.debug("Requesting %s via %s", uri, endpoint)
+
+ if parsed_uri.scheme == b"https":
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ parsed_uri.host, parsed_uri.port
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+ elif parsed_uri.scheme == b"http":
+ pass
+ else:
+ return defer.fail(
+ Failure(
+ SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+ )
+ )
+
+ return self._requestWithEndpoint(
+ pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+ )
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+ """Parses an http proxy setting and returns an endpoint for the proxy
+
+ Args:
+ proxy (bytes|None): the proxy setting
+ reactor: reactor to be used to connect to the proxy
+ kwargs: other args to be passed to HostnameEndpoint
+
+ Returns:
+ interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+ or None
+ """
+ if proxy is None:
+ return None
+
+ # currently we only support hostname:port. Some apps also support
+ # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
+ # proxy.
+
+ host, port = parse_host_port(proxy, default_port=1080)
+ return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
+ # could have sworn we had one of these somewhere else...
+ if b":" in hostport:
+ host, port = hostport.rsplit(b":", 1)
+ try:
+ port = int(port)
+ return host, port
+ except ValueError:
+ # the thing after the : wasn't a valid port; presumably this is an
+ # IPv6 address.
+ pass
+
+ return hostport, default_port
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 3220e985a9..334ddaf39a 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
def parse_drain_configs(
- drains: dict
+ drains: dict,
) -> typing.Generator[DrainConfiguration, None, None]:
"""
Parse the drain configurations.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 2bbdd11941..1ba7bcd4d8 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object):
room_members = yield self.store.get_joined_users_from_context(event, context)
- (power_levels, sender_power_level) = (
- yield self._get_power_levels_and_sender_level(event, context)
- )
+ (
+ power_levels,
+ sender_power_level,
+ ) = yield self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 42e5b0c0a5..8c818a86bf 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -234,14 +234,12 @@ class EmailPusher(object):
return
self.last_stream_ordering = last_stream_ordering
- pusher_still_exists = (
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.email,
- self.user_id,
- last_stream_ordering,
- self.clock.time_msec(),
- )
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.email,
+ self.user_id,
+ last_stream_ordering,
+ self.clock.time_msec(),
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 9a1bb64887..e994037be6 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -103,7 +103,7 @@ class HttpPusher(object):
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
@@ -211,14 +211,12 @@ class HttpPusher(object):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = (
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.pushkey,
- self.user_id,
- self.last_stream_ordering,
- self.clock.time_msec(),
- )
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_id,
+ self.last_stream_ordering,
+ self.clock.time_msec(),
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 08e840fdc2..0f6992202d 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -103,9 +103,7 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
- last_stream_ordering = (
- yield self.store.get_latest_push_action_stream_ordering()
- )
+ last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher(
user_id=user_id,
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1d8..456bc005a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict
import six
@@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
self.hs = hs
- def stream_positions(self):
+ def stream_positions(self) -> Dict[str, int]:
+ """
+ Get the current positions of all the streams this store wants to subscribe to
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 61557665a7..de50748c30 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
- result["device_lists"] = self._device_list_id_gen.get_current_token()
+ # The user signature stream uses the same stream ID generator as the
+ # device list stream, so set them both to the device list ID
+ # generator's current token.
+ current_token = self._device_list_id_gen.get_current_token()
+ result[DeviceListsStream.NAME] = current_token
+ result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "device_lists":
+ if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ elif stream_name == UserSignatureStream.NAME:
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 563ce0fc53..fead78388c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,10 +16,17 @@
"""
import logging
+from typing import Dict
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
FederationAckCommand,
InvalidateCacheCommand,
@@ -27,7 +34,6 @@ from .commands import (
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b64f3f44b5..afaf002fe6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c259..9e45429d49 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
@@ -438,3 +439,20 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token
+ self.update_function = store.get_all_user_signature_changes_for_remotes
+
+ super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 39a5c5e9de..24a0ce74f2 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -203,10 +203,11 @@ class LoginRestServlet(RestServlet):
address = address.lower()
# Check for login providers that support 3pid login types
- canonical_user_id, callback_3pid = (
- yield self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = yield self.auth_handler.check_password_provider_3pid(
+ medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
@@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet):
def do_token_login(self, login_submission):
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = (
- yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
)
result = yield self._register_device_with_callback(user_id, login_submission)
@@ -380,7 +381,7 @@ class CasTicketServlet(RestServlet):
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 332d7138b1..f26eae794c 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_password_reset_template_failure_html],
)
@@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_add_threepid_template_failure_html],
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 6c7d25d411..91db923814 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 1044ae7b4e..bb30ce3f34 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet):
"m.require_identity_server": False,
# as per MSC2290
"m.separate_add_and_bind": True,
+ # Implements support for label-based filtering as described in
+ # MSC2326.
+ "org.matrix.label_based_filtering": True,
},
},
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 55580bc59e..e7fc3f0431 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource):
@wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
- server, = request.postpath
+ (server,) = request.postpath
query = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5a25b6b3fc..531d923f76 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
diff --git a/synapse/server.py b/synapse/server.py
index 0b81af646c..f8aeebcff8 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
# Imports required for the default HomeServer() implementation
import abc
import logging
+import os
from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
@@ -168,6 +169,7 @@ class HomeServer(object):
"filtering",
"http_client_context_factory",
"simple_http_client",
+ "proxied_http_client",
"media_repository",
"media_repository_resource",
"federation_transport_client",
@@ -311,6 +313,13 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
+ def build_proxied_http_client(self):
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
+ )
+
def build_room_creation_handler(self):
return RoomCreationHandler(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 16f8f6b573..b5e0b57095 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
+import synapse.http.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -38,8 +39,16 @@ class HomeServer(object):
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
+ def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ or support any HTTP_PROXY settings"""
+ pass
+ def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ but does support HTTP_PROXY settings"""
+ pass
def get_deactivate_account_handler(
- self
+ self,
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
@@ -47,32 +56,32 @@ class HomeServer(object):
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
def get_event_creation_handler(
- self
+ self,
) -> synapse.handlers.message.EventCreationHandler:
pass
def get_set_password_handler(
- self
+ self,
) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
pass
def get_federation_transport_client(
- self
+ self,
) -> synapse.federation.transport.client.TransportLayerClient:
pass
def get_media_repository_resource(
- self
+ self,
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
def get_media_repository(
- self
+ self,
) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
def get_server_notices_manager(
- self
+ self,
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
pass
def get_server_notices_sender(
- self
+ self,
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
pass
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index b185ba0b3e..10c940df1e 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -139,7 +139,10 @@ class DataStore(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[("user_signature_stream", "stream_id")],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
@@ -317,7 +320,7 @@ class DataStore(
) u
"""
txn.execute(sql, (time_from,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
def count_r30_users(self):
@@ -396,7 +399,7 @@ class DataStore(
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
results["all"] = count
return results
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..71f62036c0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = set(r[0] for r in updates)
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
-
- update_context = update[3]
- update_stream_id = update[2]
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
- previous_update_stream_id, _ = query_map.get(key, (0, None))
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
- if update_stream_id > previous_update_stream_id:
- query_map[key] = (update_stream_id, update_context)
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
+ if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
+
return now_stream_id, results
- def _get_devices_by_remote_txn(
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
return results
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index a0bc6f2d18..073412a78d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ """Return a list of changes from the user signature stream to notify remotes.
+ Note that the user signature stream represents when a user signs their
+ device with their user-signing key, which is not published to other
+ users or servers, so no `destination` is needed in the returned
+ list. However, this is needed to poke workers.
+
+ Args:
+ from_key (int): the stream ID to start at (exclusive)
+ to_key (int): the stream ID to end at (inclusive)
+
+ Returns:
+ Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+ """
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id
+ """
+ return self._execute(
+ "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ )
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 22025effbc..04ce21ac66 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
stream_row = txn.fetchone()
if stream_row:
- offset_stream_ordering, = stream_row
+ (offset_stream_ordering,) = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index a71d7346d2..68f27078c4 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -30,7 +30,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
@@ -933,6 +933,13 @@ class EventsStore(
self._handle_event_relations(txn, event)
+ # Store the labels for this event.
+ labels = event.content.get(EventContentFields.LABELS)
+ if labels:
+ self.insert_labels_for_event_txn(
+ txn, event.event_id, labels, event.room_id, event.depth
+ )
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1126,7 +1133,7 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_messages", _count_messages)
@@ -1147,7 +1154,7 @@ class EventsStore(
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
@@ -1162,7 +1169,7 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_active_rooms", _count)
@@ -1596,7 +1603,7 @@ class EventsStore(
""",
(room_id,),
)
- min_depth, = txn.fetchone()
+ (min_depth,) = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -1905,6 +1912,33 @@ class EventsStore(
get_all_updated_current_state_deltas_txn,
)
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
+ """
+ return self._simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
+ )
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 31ea6f917f..51352b9966 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not rows:
return 0
- upper_event_id, = rows[-1]
+ (upper_event_id,) = rows[-1]
# Update the redactions with the received_ts.
#
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index aeae5a2b28..b3a2771f1b 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index e6ee1e4aaa..b41c3d317a 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index cd95f1ce60..b520062d84 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -143,7 +143,7 @@ class PushRulesWorkerStore(
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return bool(count)
return self.runInteraction(
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 6c5b29288a..f70d41ecab 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE appservice_id IS NULL
"""
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index bc04bfd7d4..2af24a20b7 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
if not row or not row[0]:
return processed, True
- next_room, = row
+ (next_room,) = row
sql = """
UPDATE current_state_events
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
new file mode 100644
index 0000000000..5e29c1da19
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topoligical_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+ event_id TEXT,
+ label TEXT,
+ room_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ PRIMARY KEY(event_id, label)
+);
+
+
+-- This index enables an event pagination looking for a particular label to index the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label, if the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 0000000000..e8b1fd35d8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ last_seen BIGINT,
+ ip TEXT,
+ user_agent TEXT,
+ hidden BOOLEAN DEFAULT 0,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index a59b8331e1..d1d7c6863d 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore):
)
)
txn.execute(query, (value, search_query))
- headline, = txn.fetchall()[0]
+ (headline,) = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
# result.
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index bf6de4ca22..e1d3041c7c 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -725,16 +725,18 @@ class StateGroupWorkerStore(
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
- non_member_state, incomplete_groups_nm, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
)
- member_state, incomplete_groups_m, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
)
state = dict(non_member_state)
@@ -1106,7 +1108,7 @@ class StateBackgroundUpdateStore(
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- prev_group, = txn.fetchone()
+ (prev_group,) = txn.fetchone()
new_last_state_group = state_group
if prev_group:
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 4d59b7833f..45b3de7d56 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore):
(room_id,),
)
- current_state_events_count, = txn.fetchone()
+ (current_state_events_count,) = txn.fetchone()
users_in_room = self.get_users_in_room_txn(txn, room_id)
@@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore):
""",
(user_id,),
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count, pos
joined_rooms, pos = yield self.runInteraction(
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 263999dfca..616ef91d4e 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -229,6 +229,14 @@ def filter_to_clause(event_filter):
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
+ # We're only applying the "labels" filter on the database query, because applying the
+ # "not_labels" filter via a SQL query is non-trivial. Instead, we let
+ # event_filter.check_fields apply it, which is not as efficient but makes the
+ # implementation simpler.
+ if event_filter.labels:
+ clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
+ args.extend(event_filter.labels)
+
return " AND ".join(clauses), args
@@ -864,8 +872,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args.append(int(limit))
sql = (
- "SELECT event_id, topological_ordering, stream_ordering"
+ "SELECT DISTINCT event_id, topological_ordering, stream_ordering"
" FROM events"
+ " LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?"
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index cbb0a4810a..9d851beaa5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- val, = cur.fetchone()
+ (val,) = cur.fetchone()
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6ba623de13..2dc5052249 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -19,6 +19,7 @@ import jsonschema
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
@@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"],
+ "org.matrix.labels": ["#fun"],
+ "org.matrix.not_labels": ["#work"],
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
@@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase):
)
self.assertFalse(Filter(definition).check(event))
+ def test_filter_labels(self):
+ definition = {"org.matrix.labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ def test_filter_not_labels(self):
+ definition = {"org.matrix.not_labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f360c8e965..5ec568f4e6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
- "get_devices_by_remote",
+ "get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_devices_by_remote.return_value = (0, [])
+ self.datastore.get_device_updates_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d7025264..cfcd98ff7d 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+ client_factory (interfaces.IProtocolFactory): the the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..af2327fb66 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f2ca74611..5e38fd6ced 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
@@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+ def test_filter_labels(self):
+ """Test that we can filter by a label."""
+ message_filter = json.dumps(
+ {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
+ )
+
+ events = self._test_filter_labels(message_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ message_filter = json.dumps(
+ {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
+ )
+
+ events = self._test_filter_labels(message_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ )
+
+ events = self._test_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_filter_labels(self, message_filter):
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ )
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&filter=%s"
+ % (self.room_id, token, message_filter),
+ )
+ self.render(request)
+
+ return channel.json_body["chunk"]
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index cdded88b7f..8ea0cb05ea 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -106,13 +106,22 @@ class RestHelper(object):
self.auth_user_id = temp_id
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
- path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = {"msgtype": "m.text", "body": body}
+
+ return self.send_event(
+ room_id, "m.room.message", content, txn_id, tok, expect_code
+ )
+
+ def send_event(
+ self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ ):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
if tok:
path = path + "?access_token=%s" % tok
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094bd..3283c0e47b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
from mock import Mock
import synapse.rest.admin
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
@@ -26,7 +28,12 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
- servlets = [sync.register_servlets]
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def make_homeserver(self, reactor, clock):
@@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
+class SyncFilterTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def test_sync_filter_labels(self):
+ """Test that we can filter by a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_sync_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_sync_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_sync_filter_labels(self, sync_filter):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ request, channel = self.make_request(
+ "GET", "/sync?filter=%s" % sync_filter, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+
+
class SyncTypingTests(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/server.py b/tests/server.py
index 469efb4edb..f878aeaada 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -395,11 +395,24 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -430,6 +443,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -474,6 +490,10 @@ class FakeTransport(object):
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
+
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
"""
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..6f8d990959 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
- def test_get_devices_by_remote(self):
+ def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
- def test_get_devices_by_remote_limited(self):
+ def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
#
# first we should get a single update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
- received_device_ids = {update["device_id"] for update in device_updates}
+ received_device_ids = {
+ update["device_id"] for edu_type, update in device_updates
+ }
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
diff --git a/tests/test_federation.py b/tests/test_federation.py
index d1acb16f30..7d82b58466 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
)
self.handler = self.homeserver.get_handlers().federation_handler
- self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+ context
+ )
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
diff --git a/tox.ini b/tox.ini
index 50b6afe611..afe9bc909b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -114,7 +114,7 @@ skip_install = True
basepython = python3.6
deps =
flake8
- black==19.3b0 # We pin so that our tests don't start failing on new releases of black.
+ black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
commands =
python -m black --check --diff .
/bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
@@ -167,6 +167,6 @@ deps =
env =
MYPYPATH = stubs/
extras = all
-commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \
+commands = mypy \
synapse/logging/ \
synapse/config/
|